diff --git a/cmd/zep/run.go b/cmd/zep/run.go index 17d14726..ec4fcad0 100644 --- a/cmd/zep/run.go +++ b/cmd/zep/run.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/signal" @@ -63,6 +64,7 @@ func NewAppState(cfg *config.Config) *models.AppState { initializeMemoryStore(appState) setupSignalHandler(appState) + setupPurgeProcessor(context.Background(), appState) return appState } @@ -123,3 +125,31 @@ func setupSignalHandler(appState *models.AppState) { os.Exit(0) }() } + +// setupPurgeProcessor sets up a go routine to purge deleted records from the MemoryStore +// at a regular interval. It's cancellable via the passed context. +// If Config.DataConfig.PurgeEvery is 0, this function does nothing. +func setupPurgeProcessor(ctx context.Context, appState *models.AppState) { + interval := time.Duration(appState.Config.DataConfig.PurgeEvery) * time.Minute + if interval == 0 { + log.Debug("purge delete processor disabled") + return + } + + log.Infof("Starting purge delete processor. Purging every %v", interval) + go func() { + for { + select { + case <-ctx.Done(): + log.Info("Stopping purge delete processor") + return + default: + err := appState.MemoryStore.PurgeDeleted(ctx) + if err != nil { + log.Errorf("error purging deleted records: %v", err) + } + } + time.Sleep(interval) + } + }() +} diff --git a/config.yaml b/config.yaml index ed12d0d0..c91c6558 100644 --- a/config.yaml +++ b/config.yaml @@ -26,6 +26,10 @@ memory_store: dsn: "postgres://postgres:postgres@localhost:5432/?sslmode=disable" server: port: 8000 +data: + # PurgeEvery is the period between hard deletes, in minutes. + # If set to 0 or undefined, hard deletes will not be performed. + purge_every: 60 log: level: "info" diff --git a/config/models.go b/config/models.go index 50db3b2e..ca0eb86a 100644 --- a/config/models.go +++ b/config/models.go @@ -10,6 +10,7 @@ type Config struct { MemoryStore MemoryStoreConfig `mapstructure:"memory_store"` Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` + DataConfig DataConfig `mapstructure:"data"` } type MemoryStoreConfig struct { @@ -45,6 +46,12 @@ type LogConfig struct { Level string `mapstructure:"level"` } +type DataConfig struct { + // PurgeEvery is the period between hard deletes, in minutes. + // If set to 0, hard deletes will not be performed. + PurgeEvery int `mapstructure:"purge_every"` +} + // ExtractorsConfig holds the configuration for all extractors type ExtractorsConfig struct { Summarizer SummarizerConfig `mapstructure:"summarizer"` diff --git a/pkg/memorystore/postgres.go b/pkg/memorystore/postgres.go index 1364a243..388e51b9 100644 --- a/pkg/memorystore/postgres.go +++ b/pkg/memorystore/postgres.go @@ -237,6 +237,15 @@ func (pms *PostgresMemoryStore) GetMessageVectors(ctx context.Context, return embeddings, nil } +func (pms *PostgresMemoryStore) PurgeDeleted(ctx context.Context) error { + err := purgeDeleted(ctx, pms.Client) + if err != nil { + return NewStorageError("failed to purge deleted", err) + } + + return nil +} + // acquireAdvisoryLock acquires a PostgreSQL advisory lock for the given key. // Expects a transaction to be open in tx. // `pg_advisory_xact_lock` will wait until the lock is available. The lock is released diff --git a/pkg/memorystore/postgres_delete.go b/pkg/memorystore/postgres_delete.go index ee3e1577..894613be 100644 --- a/pkg/memorystore/postgres_delete.go +++ b/pkg/memorystore/postgres_delete.go @@ -8,17 +8,13 @@ import ( ) // deleteSession deletes a session from the memory store. This is a soft delete. -// TODO: This is ugly. Determine why bun's cascading deletes aren't working +// Note: soft_deletes don't trigger cascade deletes, so we need to delete all +// related records manually. func deleteSession(ctx context.Context, db *bun.DB, sessionID string) error { log.Debugf("deleting from memory store for session %s", sessionID) - schemas := []bun.BeforeCreateTableHook{ - &PgMessageVectorStore{}, - &PgSummaryStore{}, - &PgMessageStore{}, - &PgSession{}, - } - for _, schema := range schemas { - log.Debugf("deleting session %s from schema %v", sessionID, schema) + + for _, schema := range tableList { + log.Debugf("deleting session %s from schema %T", sessionID, schema) _, err := db.NewDelete(). Model(schema). Where("session_id = ?", sessionID). @@ -31,3 +27,23 @@ func deleteSession(ctx context.Context, db *bun.DB, sessionID string) error { return nil } + +// purgeDeleted hard deletes all soft deleted records from the memory store. +func purgeDeleted(ctx context.Context, db *bun.DB) error { + log.Debugf("purging memory store") + + for _, schema := range tableList { + log.Debugf("purging schema %T", schema) + _, err := db.NewDelete(). + Model(schema). + WhereDeleted(). + ForceDelete(). + Exec(ctx) + if err != nil { + return fmt.Errorf("error purging rows from %T: %w", schema, err) + } + } + log.Info("completed purging memory store") + + return nil +} diff --git a/pkg/memorystore/postgres_delete_test.go b/pkg/memorystore/postgres_delete_test.go new file mode 100644 index 00000000..ba2d1774 --- /dev/null +++ b/pkg/memorystore/postgres_delete_test.go @@ -0,0 +1,107 @@ +package memorystore + +import ( + "context" + "testing" + + "github.com/uptrace/bun" + + "github.com/getzep/zep/pkg/models" + "github.com/getzep/zep/pkg/testutils" + "github.com/stretchr/testify/assert" +) + +func setupTestDeleteData(ctx context.Context, testDB *bun.DB) (string, error) { + // Test data + sessionID, err := testutils.GenerateRandomSessionID(16) + if err != nil { + return "", err + } + + _, err = putSession(ctx, testDB, sessionID, map[string]interface{}{}) + if err != nil { + return "", err + } + + messages := []models.Message{ + { + Role: "user", + Content: "Hello", + Metadata: map[string]interface{}{"timestamp": "1629462540"}, + }, + { + Role: "bot", + Content: "Hi there!", + Metadata: map[string]interface{}{"timestamp": 1629462551}, + }, + } + + // Call putMessages function + resultMessages, err := putMessages(ctx, testDB, sessionID, messages) + if err != nil { + return "", err + } + + summary := models.Summary{ + Content: "This is a summary", + Metadata: map[string]interface{}{ + "timestamp": 1629462551, + }, + SummaryPointUUID: resultMessages[0].UUID, + } + _, err = putSummary(ctx, testDB, sessionID, &summary) + if err != nil { + return "", err + } + + return sessionID, nil +} + +func TestDeleteSession(t *testing.T) { + memoryWindow := 10 + appState.Config.Memory.MessageWindow = memoryWindow + + sessionID, err := setupTestDeleteData(testCtx, testDB) + assert.NoError(t, err, "setupTestDeleteData should not return an error") + + err = deleteSession(testCtx, testDB, sessionID) + assert.NoError(t, err, "deleteSession should not return an error") + + // Test that session is deleted + resp, err := getSession(testCtx, testDB, sessionID) + assert.NoError(t, err, "getSession should not return an error") + assert.Nil(t, resp, "getSession should return nil") + + // Test that messages are deleted + respMessages, err := getMessages(testCtx, testDB, sessionID, memoryWindow, nil, 0) + assert.NoError(t, err, "getMessages should not return an error") + assert.Nil(t, respMessages, "getMessages should return nil") + + // Test that summary is deleted + respSummary, err := getSummary(testCtx, testDB, sessionID) + assert.NoError(t, err, "getSummary should not return an error") + assert.Nil(t, respSummary, "getSummary should return nil") +} + +func TestPurgeDeleted(t *testing.T) { + sessionID, err := setupTestDeleteData(testCtx, testDB) + assert.NoError(t, err, "setupTestDeleteData should not return an error") + + err = deleteSession(testCtx, testDB, sessionID) + assert.NoError(t, err, "deleteSession should not return an error") + + err = purgeDeleted(testCtx, testDB) + assert.NoError(t, err, "purgeDeleted should not return an error") + + // Test that session is deleted + for _, schema := range tableList { + r, err := testDB.NewSelect(). + Model(schema). + WhereDeleted(). + Exec(testCtx) + assert.NoError(t, err, "purgeDeleted should not return an error") + rows, err := r.RowsAffected() + assert.NoError(t, err, "RowsAffected should not return an error") + assert.True(t, rows == 0, "purgeDeleted should delete all rows") + } +} diff --git a/pkg/memorystore/postgres_schema.go b/pkg/memorystore/postgres_schema.go index 866cac76..f657cef4 100644 --- a/pkg/memorystore/postgres_schema.go +++ b/pkg/memorystore/postgres_schema.go @@ -172,6 +172,13 @@ func (*PgSummaryStore) AfterCreateTable( return err } +var tableList = []bun.BeforeCreateTableHook{ + &PgMessageVectorStore{}, + &PgSummaryStore{}, + &PgMessageStore{}, + &PgSession{}, +} + // ensurePostgresSetup creates the db schema if it does not exist. func ensurePostgresSetup( ctx context.Context, @@ -183,13 +190,9 @@ func ensurePostgresSetup( return fmt.Errorf("error creating pgvector extension: %w", err) } - schemas := []bun.BeforeCreateTableHook{ - &PgSession{}, - &PgMessageStore{}, - &PgMessageVectorStore{}, - &PgSummaryStore{}, - } - for _, schema := range schemas { + // iterate through tableList in reverse order to create tables with foreign keys first + for i := len(tableList) - 1; i >= 0; i-- { + schema := tableList[i] _, err := db.NewCreateTable(). Model(schema). IfNotExists(). diff --git a/pkg/memorystore/postgres_test.go b/pkg/memorystore/postgres_test.go index 4e6ccda6..b4184c7f 100644 --- a/pkg/memorystore/postgres_test.go +++ b/pkg/memorystore/postgres_test.go @@ -197,64 +197,6 @@ func TestGetSession(t *testing.T) { } } -func TestPgDeleteSession(t *testing.T) { - memoryWindow := 10 - appState.Config.Memory.MessageWindow = memoryWindow - - // Test data - sessionID, err := testutils.GenerateRandomSessionID(16) - assert.NoError(t, err, "GenerateRandomSessionID should not return an error") - - _, err = putSession(testCtx, testDB, sessionID, map[string]interface{}{}) - assert.NoError(t, err, "putSession should not return an error") - - messages := []models.Message{ - { - Role: "user", - Content: "Hello", - Metadata: map[string]interface{}{"timestamp": "1629462540"}, - }, - { - Role: "bot", - Content: "Hi there!", - Metadata: map[string]interface{}{"timestamp": 1629462551}, - }, - } - - // Call putMessages function - resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) - assert.NoError(t, err, "putMessages should not return an error") - - // Put a summary - summary := models.Summary{ - Content: "This is a summary", - Metadata: map[string]interface{}{ - "timestamp": 1629462551, - }, - SummaryPointUUID: resultMessages[0].UUID, - } - _, err = putSummary(testCtx, testDB, sessionID, &summary) - assert.NoError(t, err, "putSummary should not return an error") - - err = deleteSession(testCtx, testDB, sessionID) - assert.NoError(t, err, "deleteSession should not return an error") - - // Test that session is deleted - resp, err := getSession(testCtx, testDB, sessionID) - assert.NoError(t, err, "getSession should not return an error") - assert.Nil(t, resp, "getSession should return nil") - - // Test that messages are deleted - respMessages, err := getMessages(testCtx, testDB, sessionID, memoryWindow, nil, 0) - assert.NoError(t, err, "getMessages should not return an error") - assert.Nil(t, respMessages, "getMessages should return nil") - - // Test that summary is deleted - respSummary, err := getSummary(testCtx, testDB, sessionID) - assert.NoError(t, err, "getSummary should not return an error") - assert.Nil(t, respSummary, "getSummary should return nil") -} - func TestPutMessages(t *testing.T) { messages := []models.Message{ { diff --git a/pkg/models/memorystore.go b/pkg/models/memorystore.go index 712e91ff..7b64d742 100644 --- a/pkg/models/memorystore.go +++ b/pkg/models/memorystore.go @@ -73,6 +73,8 @@ type MemoryStore[T any] interface { appState *AppState, eventData *MessageEvent, ) + // PurgeDeleted hard deletes all deleted data in the MemoryStore. + PurgeDeleted(ctx context.Context) error // Close is called when the application is shutting down. This is a good place to clean up any resources used by // the MemoryStore implementation. Close() error