diff --git a/pkg/memorystore/postgres.go b/pkg/memorystore/postgres.go index 8455e15a..d4d4f1fa 100644 --- a/pkg/memorystore/postgres.go +++ b/pkg/memorystore/postgres.go @@ -185,9 +185,9 @@ func (pms *PostgresMemoryStore) SearchMemory( ctx context.Context, appState *models.AppState, sessionID string, - query *models.SearchPayload, + query *models.MemorySearchPayload, limit int, -) ([]models.SearchResult, error) { +) ([]models.MemorySearchResult, error) { searchResults, err := searchMessages(ctx, appState, pms.Client, sessionID, query, limit) return searchResults, err } diff --git a/pkg/memorystore/postgres_search.go b/pkg/memorystore/postgres_search.go index 64db3432..8c22dccb 100644 --- a/pkg/memorystore/postgres_search.go +++ b/pkg/memorystore/postgres_search.go @@ -2,8 +2,12 @@ package memorystore import ( "context" + "encoding/json" "errors" "math" + "strings" + + "github.com/sirupsen/logrus" "github.com/getzep/zep/pkg/llms" "github.com/getzep/zep/pkg/models" @@ -11,42 +15,72 @@ import ( "github.com/uptrace/bun" ) +const defaultSearchLimit = 10 + +type JSONQuery struct { + JSONPath string `json:"jsonpath"` + And []*JSONQuery `json:"and,omitempty"` + Or []*JSONQuery `json:"or,omitempty"` +} + func searchMessages( ctx context.Context, appState *models.AppState, db *bun.DB, sessionID string, - query *models.SearchPayload, + query *models.MemorySearchPayload, limit int, -) ([]models.SearchResult, error) { - if query == nil { - return nil, NewStorageError("nil query received", nil) +) ([]models.MemorySearchResult, error) { + logrus.Debugf("searchMessages called for session %s", sessionID) + + if query == nil || appState == nil { + return nil, NewStorageError("nil query or appState received", nil) } - s := query.Text - if s == "" { + if query.Text == "" && len(query.Metadata) == 0 { return nil, NewStorageError("empty query", errors.New("empty query")) } - if appState == nil { - return nil, NewStorageError("nil appState received", nil) + dbQuery := buildDBSelectQuery(ctx, appState, db, query) + if len(query.Metadata) > 0 { + var err error + dbQuery, err = applyMetadataFilter(dbQuery, query.Metadata) + if err != nil { + return nil, NewStorageError("error applying metadata filter", err) + } } - log.Debugf("searchMessages called for session %s", sessionID) + dbQuery = dbQuery.Where("m.session_id = ?", sessionID) + + // Ensure we don't return deleted records. + dbQuery = dbQuery.Where("m.deleted_at IS NULL") + + // Add sort and limit. + sortQuery(query.Text, dbQuery) if limit == 0 { - limit = 10 + limit = defaultSearchLimit } + dbQuery = dbQuery.Limit(limit) - e, err := llms.EmbedMessages(ctx, appState, []string{s}) + results, err := executeScan(ctx, dbQuery) if err != nil { - return nil, NewStorageError("failed to embed query", err) + return nil, NewStorageError("memory searchMessages failed", err) } - vector := pgvector.NewVector(e[0].Embedding) - var results []models.SearchResult - err = db.NewSelect(). - TableExpr("message_embedding AS me"). + filteredResults := filterValidResults(results, query.Metadata) + logrus.Debugf("searchMessages completed for session %s", sessionID) + + return filteredResults, nil +} + +func buildDBSelectQuery( + ctx context.Context, + appState *models.AppState, + db *bun.DB, + query *models.MemorySearchPayload, +) *bun.SelectQuery { + dbQuery := db.NewSelect().TableExpr("message_embedding AS me"). Join("JOIN message AS m"). JoinOn("me.message_uuid = m.uuid"). ColumnExpr("m.uuid AS message__uuid"). @@ -54,25 +88,133 @@ func searchMessages( ColumnExpr("m.role AS message__role"). ColumnExpr("m.content AS message__content"). ColumnExpr("m.metadata AS message__metadata"). - ColumnExpr("m.token_count AS message__token_count"). - ColumnExpr("1 - (embedding <=> ? ) AS dist", vector). - Where("m.session_id = ?", sessionID). - Order("dist DESC"). - Limit(limit). - Scan(ctx, &results) - if err != nil { - return nil, NewStorageError("memory searchMessages failed", err) + ColumnExpr("m.token_count AS message__token_count") + + if query.Text != "" { + dbQuery, _ = addVectorColumn(ctx, appState, dbQuery, query.Text) } - // some results may be returned where distance is NaN. This is a race between - // newly added messages and the search query. We filter these out. - var filteredResults []models.SearchResult + return dbQuery +} + +func applyMetadataFilter( + dbQuery *bun.SelectQuery, + metadata map[string]interface{}, +) (*bun.SelectQuery, error) { + qb := dbQuery.QueryBuilder() + + if where, ok := metadata["where"]; ok { + j, err := json.Marshal(where) + if err != nil { + return nil, NewStorageError("error marshalling metadata", err) + } + + var jq JSONQuery + err = json.Unmarshal(j, &jq) + if err != nil { + return nil, NewStorageError("error unmarshalling metadata", err) + } + qb = parseJSONQuery(qb, &jq, false) + } + + addDateFilters(&qb, metadata) + + dbQuery = qb.Unwrap().(*bun.SelectQuery) + + return dbQuery, nil +} + +func sortQuery(searchText string, dbQuery *bun.SelectQuery) { + if searchText != "" { + dbQuery.Order("dist DESC") + } else { + dbQuery.Order("m.created_at DESC") + } +} + +func executeScan( + ctx context.Context, + dbQuery *bun.SelectQuery, +) ([]models.MemorySearchResult, error) { + var results []models.MemorySearchResult + err := dbQuery.Scan(ctx, &results) + return results, err +} + +func filterValidResults( + results []models.MemorySearchResult, + metadata map[string]interface{}, +) []models.MemorySearchResult { + var filteredResults []models.MemorySearchResult for _, result := range results { - if !math.IsNaN(result.Dist) { + if !math.IsNaN(result.Dist) || len(metadata) > 0 { filteredResults = append(filteredResults, result) } } - log.Debugf("searchMessages completed for session %s", sessionID) + return filteredResults +} - return filteredResults, nil +// addDateFilters adds date filters to the query +func addDateFilters(qb *bun.QueryBuilder, m map[string]interface{}) { + if startDate, ok := m["start_date"]; ok { + *qb = (*qb).Where("m.created_at >= ?", startDate) + } + if endDate, ok := m["end_date"]; ok { + *qb = (*qb).Where("m.created_at <= ?", endDate) + } +} + +// addVectorColumn adds a column to the query that calculates the distance between the query text and the message embedding +func addVectorColumn( + ctx context.Context, + appState *models.AppState, + q *bun.SelectQuery, + queryText string, +) (*bun.SelectQuery, error) { + e, err := llms.EmbedMessages(ctx, appState, []string{queryText}) + if err != nil { + return nil, NewStorageError("failed to embed query", err) + } + + vector := pgvector.NewVector(e[0].Embedding) + return q.ColumnExpr("1 - (embedding <=> ? ) AS dist", vector), nil +} + +// parseJSONQuery recursively parses a JSONQuery and returns a bun.QueryBuilder. +// TODO: fix the addition of extraneous parentheses in the query +func parseJSONQuery(qb bun.QueryBuilder, jq *JSONQuery, isOr bool) bun.QueryBuilder { + if jq.JSONPath != "" { + path := strings.ReplaceAll(jq.JSONPath, "'", "\"") + if isOr { + qb = qb.WhereOr( + "jsonb_path_exists(m.metadata, ?)", + path, + ) + } else { + qb = qb.Where( + "jsonb_path_exists(m.metadata, ?)", + path, + ) + } + } + + if len(jq.And) > 0 { + qb = qb.WhereGroup(" AND ", func(qq bun.QueryBuilder) bun.QueryBuilder { + for _, subQuery := range jq.And { + qq = parseJSONQuery(qq, subQuery, false) + } + return qq + }) + } + + if len(jq.Or) > 0 { + qb = qb.WhereGroup(" AND ", func(qq bun.QueryBuilder) bun.QueryBuilder { + for _, subQuery := range jq.Or { + qq = parseJSONQuery(qq, subQuery, true) + } + return qq + }) + } + + return qb } diff --git a/pkg/memorystore/postgres_search_test.go b/pkg/memorystore/postgres_search_test.go new file mode 100644 index 00000000..2203b06c --- /dev/null +++ b/pkg/memorystore/postgres_search_test.go @@ -0,0 +1,187 @@ +package memorystore + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/uptrace/bun" + + "github.com/getzep/zep/pkg/models" + "github.com/getzep/zep/pkg/testutils" + "github.com/stretchr/testify/assert" +) + +func TestVectorSearch(t *testing.T) { + // Test data + sessionID, err := testutils.GenerateRandomSessionID(16) + assert.NoError(t, err, "GenerateRandomSessionID should not return an error") + + // Call putMessages function + msgs, err := putMessages(testCtx, testDB, sessionID, testutils.TestMessages) + assert.NoError(t, err, "putMessages should not return an error") + + appState.MemoryStore.NotifyExtractors( + context.Background(), + appState, + &models.MessageEvent{SessionID: sessionID, + Messages: msgs}, + ) + + // enrichment runs async. Wait for it to finish + // This is hacky but I'd prefer not to add a WaitGroup to the putMessages function just for testing purposes + time.Sleep(time.Second * 2) + + // Test cases + testCases := []struct { + name string + query string + limit int + expectedErrorText string + }{ + {"Empty Query", "", 0, "empty query"}, + {"Non-empty Query", "travel", 0, ""}, + {"Limit 0", "travel", 0, ""}, + {"Limit 5", "travel", 5, ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := models.MemorySearchPayload{Text: tc.query} + expectedLastN := tc.limit + if expectedLastN == 0 { + expectedLastN = 10 // Default value + } + + s, err := searchMessages(testCtx, appState, testDB, sessionID, &q, expectedLastN) + + if tc.expectedErrorText != "" { + assert.ErrorContains( + t, + err, + tc.expectedErrorText, + "searchMessages should return the expected error", + ) + } else { + assert.NoError(t, err, "searchMessages should not return an error") + assert.Len(t, s, expectedLastN, fmt.Sprintf("Expected %d messages to be returned", expectedLastN)) + + for _, res := range s { + assert.NotNil(t, res.Message.UUID, "message__uuid should be present") + assert.NotNil(t, res.Message.CreatedAt, "message__created_at should be present") + assert.NotNil(t, res.Message.Role, "message__role should be present") + assert.NotNil(t, res.Message.Content, "message__content should be present") + assert.NotZero(t, res.Message.TokenCount, "message_token_count should be present") + } + } + }) + } +} + +func TestParseJSONQuery(t *testing.T) { + tests := []struct { + name string + jsonQuery string + expectedCond string + }{ + { + name: "Test 1", + jsonQuery: `{"where": {"jsonpath": "$.system.entities[*] ? (@.Label == \"DATE\")"}}`, + expectedCond: `WHERE (jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Label == "DATE")'))`, + }, + { + name: "Test 2", + jsonQuery: `{"where": {"or": [{"jsonpath": "$.system.entities[*] ? (@.Label == \"DATE\")"},{"jsonpath": "$.system.entities[*] ? (@.Label == \"ORG\")"}]}}`, + expectedCond: `WHERE ((jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Label == "DATE")')) OR (jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Label == "ORG")')))`, + }, + { + name: "Test 3", + jsonQuery: `{"where": {"and": [{"jsonpath": "$.system.entities[*] ? (@.Label == \"DATE\")"},{"jsonpath": "$.system.entities[*] ? (@.Label == \"ORG\")"},{"or": [{"jsonpath": "$.system.entities[*] ? (@.Name == \"Iceland\")"},{"jsonpath": "$.system.entities[*] ? (@.Name == \"Canada\")"}]}]}}`, + expectedCond: `WHERE ((jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Label == "DATE")')) AND (jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Label == "ORG")')) AND ((jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Name == "Iceland")')) OR (jsonb_path_exists(m.metadata, '$.system.entities[*] ? (@.Name == "Canada")'))))`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + qb := testDB.NewSelect(). + Model(&[]models.MemorySearchResult{}). + QueryBuilder() + + var metadata map[string]interface{} + err := json.Unmarshal([]byte(tt.jsonQuery), &metadata) + assert.NoError(t, err) + + query, err := json.Marshal(metadata["where"]) + assert.NoError(t, err) + + var jsonQuery JSONQuery + err = json.Unmarshal(query, &jsonQuery) + assert.NoError(t, err) + + qb = parseJSONQuery(qb, &jsonQuery, false) + + selectQuery := qb.Unwrap().(*bun.SelectQuery) + + // Extract the WHERE conditions from the SQL query + sql := selectQuery.String() + whereIndex := strings.Index(sql, "WHERE") + assert.True(t, whereIndex > 0, "WHERE clause should be present") + cond := sql[whereIndex:] + + // We use assert.Equal to test if the conditions are built correctly. + assert.Equal(t, tt.expectedCond, cond) + }) + } +} + +func TestAddDateFilters(t *testing.T) { + tests := []struct { + name string + inputDates string + expectedCond string + }{ + { + name: "Test 1 - Start Date only", + inputDates: `{"start_date": "2022-01-01"}`, + expectedCond: `WHERE (m.created_at >= '2022-01-01')`, + }, + { + name: "Test 2 - End Date only", + inputDates: `{"end_date": "2022-01-31"}`, + expectedCond: `WHERE (m.created_at <= '2022-01-31')`, + }, + { + name: "Test 3 - Start and End Dates", + inputDates: `{"start_date": "2022-01-01", "end_date": "2022-01-31"}`, + expectedCond: `WHERE (m.created_at >= '2022-01-01') AND (m.created_at <= '2022-01-31')`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + qb := testDB.NewSelect(). + Model(&[]models.MemorySearchResult{}). + QueryBuilder() + + var inputDates map[string]interface{} + err := json.Unmarshal([]byte(tt.inputDates), &inputDates) + assert.NoError(t, err) + + addDateFilters(&qb, inputDates) + + selectQuery := qb.Unwrap().(*bun.SelectQuery) + + // Extract the WHERE conditions from the SQL query + sql := selectQuery.String() + whereIndex := strings.Index(sql, "WHERE") + assert.True(t, whereIndex > 0, "WHERE clause should be present") + cond := sql[whereIndex:] + + // We use assert.Equal to test if the conditions are built correctly. + assert.Equal(t, tt.expectedCond, cond) + }) + } +} diff --git a/pkg/memorystore/postgres_test.go b/pkg/memorystore/postgres_test.go index 6ee0e426..17640f82 100644 --- a/pkg/memorystore/postgres_test.go +++ b/pkg/memorystore/postgres_test.go @@ -763,7 +763,7 @@ func TestSearch(t *testing.T) { // enrichment runs async. Wait for it to finish // This is hacky but I'd prefer not to add a WaitGroup to the putMessages function just for testing purposes - time.Sleep(time.Second * 4) + time.Sleep(time.Second * 2) // Test cases testCases := []struct { @@ -780,7 +780,7 @@ func TestSearch(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - q := models.SearchPayload{Text: tc.query} + q := models.MemorySearchPayload{Text: tc.query} expectedLastN := tc.limit if expectedLastN == 0 { expectedLastN = 10 // Default value diff --git a/pkg/models/memorystore.go b/pkg/models/memorystore.go index 184d2c64..d144bf4c 100644 --- a/pkg/models/memorystore.go +++ b/pkg/models/memorystore.go @@ -50,15 +50,14 @@ type MemoryStore[T any] interface { GetMessageVectors(ctx context.Context, appState *AppState, sessionID string) ([]Embeddings, error) - // SearchMemory retrieves a collection of SearchResults for a given sessionID and query. Currently, the query - // is a simple string, but this could be extended to support more complex queries in the future. The SearchResult - // structure can include both Messages and Summaries. Currently, we only search Messages. + // SearchMemory retrieves a collection of SearchResults for a given sessionID and query. Currently, The + // MemorySearchResult structure can include both Messages and Summaries. Currently, we only search Messages. SearchMemory( ctx context.Context, appState *AppState, sessionID string, - query *SearchPayload, - limit int) ([]SearchResult, error) + query *MemorySearchPayload, + limit int) ([]MemorySearchResult, error) // DeleteSession deletes all records for a given sessionID. This is a soft delete. Hard deletes will be handled // by a separate process or left to the implementation. DeleteSession(ctx context.Context, sessionID string) error diff --git a/pkg/models/search.go b/pkg/models/search.go index b565dd11..2170143a 100644 --- a/pkg/models/search.go +++ b/pkg/models/search.go @@ -1,13 +1,13 @@ package models -type SearchResult struct { +type MemorySearchResult struct { Message *Message `json:"message"` Summary *Summary `json:"summary"` // reserved for future use - Metadata map[string]interface{} `json:"meta,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` Dist float64 `json:"dist"` } -type SearchPayload struct { +type MemorySearchPayload struct { Text string `json:"text"` - Metadata map[string]interface{} `json:"meta,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` } diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index da9aee8f..df4ab688 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -117,7 +117,7 @@ func DeleteMemoryHandler(appState *models.AppState) http.HandlerFunc { } } -// RunSearchHandler godoc +// SearchMemoryHandler godoc // // @Summary Search memory messages for a given session // @Description search memory messages by session id and query @@ -126,15 +126,15 @@ func DeleteMemoryHandler(appState *models.AppState) http.HandlerFunc { // @Produce json // @Param session_id path string true "Session ID" // @Param limit query integer false "Limit the number of results returned" -// @Param searchPayload body models.SearchPayload true "Search query" -// @Success 200 {object} []models.SearchResult +// @Param searchPayload body models.MemorySearchPayload true "Search query" +// @Success 200 {object} []models.MemorySearchResult // @Failure 404 {object} APIError "Not Found" // @Failure 500 {object} APIError "Internal Server Error" // @Router /api/v1/sessions/{sessionId}/search [post] -func RunSearchHandler(appState *models.AppState) http.HandlerFunc { +func SearchMemoryHandler(appState *models.AppState) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sessionID := chi.URLParam(r, "sessionId") - var payload models.SearchPayload + var payload models.MemorySearchPayload if err := decodeJSON(r, &payload); err != nil { renderError(w, err, http.StatusBadRequest) return diff --git a/pkg/server/routes.go b/pkg/server/routes.go index c7c263a5..ffa18e73 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -27,10 +27,9 @@ func Create(appState *models.AppState) *http.Server { } } -// @title Zep REST API -// @license.name Apache 2.0 -// @license.url http://www.apache.org/licenses/LICENSE-2.0.html - +// @title Zep REST API +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html // @BasePath /apt/v1 // @schemes http https func setupRouter(appState *models.AppState) *chi.Mux { @@ -51,7 +50,7 @@ func setupRouter(appState *models.AppState) *chi.Mux { }) // Search-related routes r.Route("/search", func(r chi.Router) { - r.Post("/", RunSearchHandler(appState)) + r.Post("/", SearchMemoryHandler(appState)) }) }) }) diff --git a/pkg/testutils/utils.go b/pkg/testutils/utils.go index d7f133c1..6cda04c3 100644 --- a/pkg/testutils/utils.go +++ b/pkg/testutils/utils.go @@ -1,5 +1,4 @@ //go:build testutils -// +build testutils package testutils