Skip to content

Commit

Permalink
hybrid search (#85)
Browse files Browse the repository at this point in the history
Implement alongside vector search:

- hybrid search over metadata using jsonpath queries.
- hybrid search using message creation timestamps
  • Loading branch information
danielchalef authored Jun 3, 2023
1 parent 84328fd commit b3287e0
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 54 deletions.
4 changes: 2 additions & 2 deletions pkg/memorystore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ func (pms *PostgresMemoryStore) SearchMemory(
ctx context.Context,
appState *models.AppState,
sessionID string,
query *models.SearchPayload,
query *models.MemorySearchPayload,
limit int,
) ([]models.SearchResult, error) {
) ([]models.MemorySearchResult, error) {
searchResults, err := searchMessages(ctx, appState, pms.Client, sessionID, query, limit)
return searchResults, err
}
Expand Down
202 changes: 172 additions & 30 deletions pkg/memorystore/postgres_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,219 @@ package memorystore

import (
"context"
"encoding/json"
"errors"
"math"
"strings"

"github.com/sirupsen/logrus"

"github.com/getzep/zep/pkg/llms"
"github.com/getzep/zep/pkg/models"
"github.com/pgvector/pgvector-go"
"github.com/uptrace/bun"
)

const defaultSearchLimit = 10

type JSONQuery struct {
JSONPath string `json:"jsonpath"`
And []*JSONQuery `json:"and,omitempty"`
Or []*JSONQuery `json:"or,omitempty"`
}

func searchMessages(
ctx context.Context,
appState *models.AppState,
db *bun.DB,
sessionID string,
query *models.SearchPayload,
query *models.MemorySearchPayload,
limit int,
) ([]models.SearchResult, error) {
if query == nil {
return nil, NewStorageError("nil query received", nil)
) ([]models.MemorySearchResult, error) {
logrus.Debugf("searchMessages called for session %s", sessionID)

if query == nil || appState == nil {
return nil, NewStorageError("nil query or appState received", nil)
}

s := query.Text
if s == "" {
if query.Text == "" && len(query.Metadata) == 0 {
return nil, NewStorageError("empty query", errors.New("empty query"))
}

if appState == nil {
return nil, NewStorageError("nil appState received", nil)
dbQuery := buildDBSelectQuery(ctx, appState, db, query)
if len(query.Metadata) > 0 {
var err error
dbQuery, err = applyMetadataFilter(dbQuery, query.Metadata)
if err != nil {
return nil, NewStorageError("error applying metadata filter", err)
}
}

log.Debugf("searchMessages called for session %s", sessionID)
dbQuery = dbQuery.Where("m.session_id = ?", sessionID)

// Ensure we don't return deleted records.
dbQuery = dbQuery.Where("m.deleted_at IS NULL")

// Add sort and limit.
sortQuery(query.Text, dbQuery)

if limit == 0 {
limit = 10
limit = defaultSearchLimit
}
dbQuery = dbQuery.Limit(limit)

e, err := llms.EmbedMessages(ctx, appState, []string{s})
results, err := executeScan(ctx, dbQuery)
if err != nil {
return nil, NewStorageError("failed to embed query", err)
return nil, NewStorageError("memory searchMessages failed", err)
}
vector := pgvector.NewVector(e[0].Embedding)

var results []models.SearchResult
err = db.NewSelect().
TableExpr("message_embedding AS me").
filteredResults := filterValidResults(results, query.Metadata)
logrus.Debugf("searchMessages completed for session %s", sessionID)

return filteredResults, nil
}

func buildDBSelectQuery(
ctx context.Context,
appState *models.AppState,
db *bun.DB,
query *models.MemorySearchPayload,
) *bun.SelectQuery {
dbQuery := db.NewSelect().TableExpr("message_embedding AS me").
Join("JOIN message AS m").
JoinOn("me.message_uuid = m.uuid").
ColumnExpr("m.uuid AS message__uuid").
ColumnExpr("m.created_at AS message__created_at").
ColumnExpr("m.role AS message__role").
ColumnExpr("m.content AS message__content").
ColumnExpr("m.metadata AS message__metadata").
ColumnExpr("m.token_count AS message__token_count").
ColumnExpr("1 - (embedding <=> ? ) AS dist", vector).
Where("m.session_id = ?", sessionID).
Order("dist DESC").
Limit(limit).
Scan(ctx, &results)
if err != nil {
return nil, NewStorageError("memory searchMessages failed", err)
ColumnExpr("m.token_count AS message__token_count")

if query.Text != "" {
dbQuery, _ = addVectorColumn(ctx, appState, dbQuery, query.Text)
}

// some results may be returned where distance is NaN. This is a race between
// newly added messages and the search query. We filter these out.
var filteredResults []models.SearchResult
return dbQuery
}

func applyMetadataFilter(
dbQuery *bun.SelectQuery,
metadata map[string]interface{},
) (*bun.SelectQuery, error) {
qb := dbQuery.QueryBuilder()

if where, ok := metadata["where"]; ok {
j, err := json.Marshal(where)
if err != nil {
return nil, NewStorageError("error marshalling metadata", err)
}

var jq JSONQuery
err = json.Unmarshal(j, &jq)
if err != nil {
return nil, NewStorageError("error unmarshalling metadata", err)
}
qb = parseJSONQuery(qb, &jq, false)
}

addDateFilters(&qb, metadata)

dbQuery = qb.Unwrap().(*bun.SelectQuery)

return dbQuery, nil
}

func sortQuery(searchText string, dbQuery *bun.SelectQuery) {
if searchText != "" {
dbQuery.Order("dist DESC")
} else {
dbQuery.Order("m.created_at DESC")
}
}

func executeScan(
ctx context.Context,
dbQuery *bun.SelectQuery,
) ([]models.MemorySearchResult, error) {
var results []models.MemorySearchResult
err := dbQuery.Scan(ctx, &results)
return results, err
}

func filterValidResults(
results []models.MemorySearchResult,
metadata map[string]interface{},
) []models.MemorySearchResult {
var filteredResults []models.MemorySearchResult
for _, result := range results {
if !math.IsNaN(result.Dist) {
if !math.IsNaN(result.Dist) || len(metadata) > 0 {
filteredResults = append(filteredResults, result)
}
}
log.Debugf("searchMessages completed for session %s", sessionID)
return filteredResults
}

return filteredResults, nil
// addDateFilters adds date filters to the query
func addDateFilters(qb *bun.QueryBuilder, m map[string]interface{}) {
if startDate, ok := m["start_date"]; ok {
*qb = (*qb).Where("m.created_at >= ?", startDate)
}
if endDate, ok := m["end_date"]; ok {
*qb = (*qb).Where("m.created_at <= ?", endDate)
}
}

// addVectorColumn adds a column to the query that calculates the distance between the query text and the message embedding
func addVectorColumn(
ctx context.Context,
appState *models.AppState,
q *bun.SelectQuery,
queryText string,
) (*bun.SelectQuery, error) {
e, err := llms.EmbedMessages(ctx, appState, []string{queryText})
if err != nil {
return nil, NewStorageError("failed to embed query", err)
}

vector := pgvector.NewVector(e[0].Embedding)
return q.ColumnExpr("1 - (embedding <=> ? ) AS dist", vector), nil
}

// parseJSONQuery recursively parses a JSONQuery and returns a bun.QueryBuilder.
// TODO: fix the addition of extraneous parentheses in the query
func parseJSONQuery(qb bun.QueryBuilder, jq *JSONQuery, isOr bool) bun.QueryBuilder {
if jq.JSONPath != "" {
path := strings.ReplaceAll(jq.JSONPath, "'", "\"")
if isOr {
qb = qb.WhereOr(
"jsonb_path_exists(m.metadata, ?)",
path,
)
} else {
qb = qb.Where(
"jsonb_path_exists(m.metadata, ?)",
path,
)
}
}

if len(jq.And) > 0 {
qb = qb.WhereGroup(" AND ", func(qq bun.QueryBuilder) bun.QueryBuilder {
for _, subQuery := range jq.And {
qq = parseJSONQuery(qq, subQuery, false)
}
return qq
})
}

if len(jq.Or) > 0 {
qb = qb.WhereGroup(" AND ", func(qq bun.QueryBuilder) bun.QueryBuilder {
for _, subQuery := range jq.Or {
qq = parseJSONQuery(qq, subQuery, true)
}
return qq
})
}

return qb
}
Loading

0 comments on commit b3287e0

Please sign in to comment.