diff --git a/chains/chains.go b/chains/chains.go index 156498403..17dd3aff0 100644 --- a/chains/chains.go +++ b/chains/chains.go @@ -26,13 +26,13 @@ type Chain interface { } // Call is the standard function used for executing chains. -func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint: lll +func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll fullValues := make(map[string]any, 0) for key, value := range inputValues { fullValues[key] = value } - newValues, err := c.GetMemory().LoadMemoryVariables(inputValues) + newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C return nil, err } - err = c.GetMemory().SaveContext(inputValues, outputValues) + err = c.GetMemory().SaveContext(ctx, inputValues, outputValues) if err != nil { return nil, err } @@ -65,7 +65,7 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C // string output. func Run(ctx context.Context, c Chain, input any, options ...ChainCallOption) (string, error) { inputKeys := c.GetInputKeys() - memoryKeys := c.GetMemory().MemoryVariables() + memoryKeys := c.GetMemory().MemoryVariables(ctx) neededKeys := make([]string, 0, len(inputKeys)) // Remove keys gotten from the memory. diff --git a/chains/conversational_retrieval_qa.go b/chains/conversational_retrieval_qa.go index f154fc37b..ac6fc6110 100644 --- a/chains/conversational_retrieval_qa.go +++ b/chains/conversational_retrieval_qa.go @@ -88,14 +88,14 @@ func NewConversationalRetrievalQAFromLLM( // Call gets question, and relevant documents by question from the retriever and gives them to the combine // documents chain. -func (c ConversationalRetrievalQA) Call(ctx context.Context, values map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint: lll +func (c ConversationalRetrievalQA) Call(ctx context.Context, values map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll query, ok := values[c.InputKey].(string) if !ok { return nil, fmt.Errorf("%w: %w", ErrInvalidInputValues, ErrInputValuesWrongType) } - chatHistoryStr, ok := values[c.Memory.GetMemoryKey()].(string) + chatHistoryStr, ok := values[c.Memory.GetMemoryKey(ctx)].(string) if !ok { - chatHistory, ok := values[c.Memory.GetMemoryKey()].([]schema.ChatMessage) + chatHistory, ok := values[c.Memory.GetMemoryKey(ctx)].([]schema.ChatMessage) if !ok { return nil, fmt.Errorf("%w: %w", ErrMissingMemoryKeyValues, ErrMemoryValuesWrongType) } diff --git a/chains/sequential.go b/chains/sequential.go index 6ce7e1860..a9ea2eef1 100644 --- a/chains/sequential.go +++ b/chains/sequential.go @@ -45,7 +45,7 @@ func (c *SequentialChain) validateSeqChain() error { knownKeys := util.ToSet(c.inputKeys) // Make sure memory keys don't collide with input keys - memoryKeys := c.memory.MemoryVariables() + memoryKeys := c.memory.MemoryVariables(context.Background()) overlappingKeys := util.Intersection(memoryKeys, knownKeys) if len(overlappingKeys) > 0 { return fmt.Errorf( diff --git a/memory/buffer.go b/memory/buffer.go index 54f63f0d4..e48732986 100644 --- a/memory/buffer.go +++ b/memory/buffer.go @@ -1,6 +1,7 @@ package memory import ( + "context" "errors" "fmt" @@ -31,7 +32,7 @@ func NewConversationBuffer(options ...ConversationBufferOption) *ConversationBuf } // MemoryVariables gets the input key the buffer memory class will load dynamically. -func (m *ConversationBuffer) MemoryVariables() []string { +func (m *ConversationBuffer) MemoryVariables(context.Context) []string { return []string{m.MemoryKey} } @@ -39,8 +40,10 @@ func (m *ConversationBuffer) MemoryVariables() []string { // are returned in a map with the key specified in the MemoryKey field. This key defaults to // "history". If ReturnMessages is set to true the output is a slice of schema.ChatMessage. Otherwise // the output is a buffer string of the chat messages. -func (m *ConversationBuffer) LoadMemoryVariables(map[string]any) (map[string]any, error) { - messages, err := m.ChatHistory.Messages() +func (m *ConversationBuffer) LoadMemoryVariables( + ctx context.Context, _ map[string]any, +) (map[string]any, error) { + messages, err := m.ChatHistory.Messages(ctx) if err != nil { return nil, err } @@ -68,12 +71,16 @@ func (m *ConversationBuffer) LoadMemoryVariables(map[string]any) (map[string]any // input key must be a key in the input values and the output key must be a key in the output // values. The values in the input and output values used to save a user and ai message must // be strings. -func (m *ConversationBuffer) SaveContext(inputValues map[string]any, outputValues map[string]any) error { +func (m *ConversationBuffer) SaveContext( + ctx context.Context, + inputValues map[string]any, + outputValues map[string]any, +) error { userInputValue, err := getInputValue(inputValues, m.InputKey) if err != nil { return err } - err = m.ChatHistory.AddUserMessage(userInputValue) + err = m.ChatHistory.AddUserMessage(ctx, userInputValue) if err != nil { return err } @@ -82,7 +89,7 @@ func (m *ConversationBuffer) SaveContext(inputValues map[string]any, outputValue if err != nil { return err } - err = m.ChatHistory.AddAIMessage(aiOutputValue) + err = m.ChatHistory.AddAIMessage(ctx, aiOutputValue) if err != nil { return err } @@ -95,7 +102,7 @@ func (m *ConversationBuffer) Clear() error { return m.ChatHistory.Clear() } -func (m *ConversationBuffer) GetMemoryKey() string { +func (m *ConversationBuffer) GetMemoryKey(context.Context) string { return m.MemoryKey } diff --git a/memory/buffer_test.go b/memory/buffer_test.go index 97562269f..a0c84eeb5 100644 --- a/memory/buffer_test.go +++ b/memory/buffer_test.go @@ -1,10 +1,12 @@ package memory import ( + "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/schema" ) @@ -12,15 +14,15 @@ func TestBufferMemory(t *testing.T) { t.Parallel() m := NewConversationBuffer() - result1, err := m.LoadMemoryVariables(map[string]any{}) + result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected1 := map[string]any{"history": ""} assert.Equal(t, expected1, result1) - err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) + err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) require.NoError(t, err) - result2, err := m.LoadMemoryVariables(map[string]any{}) + result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected2 := map[string]any{"history": "Human: bar\nAI: foo"} @@ -33,14 +35,14 @@ func TestBufferMemoryReturnMessage(t *testing.T) { m := NewConversationBuffer() m.ReturnMessages = true expected1 := map[string]any{"history": []schema.ChatMessage{}} - result1, err := m.LoadMemoryVariables(map[string]any{}) + result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) assert.Equal(t, expected1, result1) - err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) + err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) require.NoError(t, err) - result2, err := m.LoadMemoryVariables(map[string]any{}) + result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expectedChatHistory := NewChatMessageHistory( @@ -50,7 +52,7 @@ func TestBufferMemoryReturnMessage(t *testing.T) { }), ) - messages, err := expectedChatHistory.Messages() + messages, err := expectedChatHistory.Messages(context.Background()) assert.NoError(t, err) expected2 := map[string]any{"history": messages} assert.Equal(t, expected2, result2) @@ -66,7 +68,7 @@ func TestBufferMemoryWithPreLoadedHistory(t *testing.T) { }), ))) - result, err := m.LoadMemoryVariables(map[string]any{}) + result, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected := map[string]any{"history": "Human: bar\nAI: foo"} assert.Equal(t, expected, result) @@ -76,15 +78,15 @@ type testChatMessageHistory struct{} var _ schema.ChatMessageHistory = testChatMessageHistory{} -func (t testChatMessageHistory) AddUserMessage(_ string) error { +func (t testChatMessageHistory) AddUserMessage(context.Context, string) error { return nil } -func (t testChatMessageHistory) AddAIMessage(_ string) error { +func (t testChatMessageHistory) AddAIMessage(context.Context, string) error { return nil } -func (t testChatMessageHistory) AddMessage(_ schema.ChatMessage) error { +func (t testChatMessageHistory) AddMessage(context.Context, schema.ChatMessage) error { return nil } @@ -92,11 +94,11 @@ func (t testChatMessageHistory) Clear() error { return nil } -func (t testChatMessageHistory) SetMessages(_ []schema.ChatMessage) error { +func (t testChatMessageHistory) SetMessages(context.Context, []schema.ChatMessage) error { return nil } -func (t testChatMessageHistory) Messages() ([]schema.ChatMessage, error) { +func (t testChatMessageHistory) Messages(context.Context) ([]schema.ChatMessage, error) { return []schema.ChatMessage{ schema.HumanChatMessage{Content: "user message test"}, schema.AIChatMessage{Content: "ai message test"}, @@ -109,7 +111,7 @@ func TestBufferMemoryWithChatHistoryOption(t *testing.T) { chatMessageHistory := testChatMessageHistory{} m := NewConversationBuffer(WithChatHistory(chatMessageHistory)) - result, err := m.LoadMemoryVariables(map[string]any{}) + result, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected := map[string]any{"history": "Human: user message test\nAI: ai message test"} assert.Equal(t, expected, result) diff --git a/memory/chat.go b/memory/chat.go index 4bd6225ae..299ce7ae9 100644 --- a/memory/chat.go +++ b/memory/chat.go @@ -1,6 +1,10 @@ package memory -import "github.com/tmc/langchaingo/schema" +import ( + "context" + + "github.com/tmc/langchaingo/schema" +) // ChatMessageHistory is a struct that stores chat messages. type ChatMessageHistory struct { @@ -16,18 +20,18 @@ func NewChatMessageHistory(options ...ChatMessageHistoryOption) *ChatMessageHist } // Messages returns all messages stored. -func (h *ChatMessageHistory) Messages() ([]schema.ChatMessage, error) { +func (h *ChatMessageHistory) Messages(_ context.Context) ([]schema.ChatMessage, error) { return h.messages, nil } // AddAIMessage adds an AIMessage to the chat message history. -func (h *ChatMessageHistory) AddAIMessage(text string) error { +func (h *ChatMessageHistory) AddAIMessage(_ context.Context, text string) error { h.messages = append(h.messages, schema.AIChatMessage{Content: text}) return nil } // AddUserMessage adds an user to the chat message history. -func (h *ChatMessageHistory) AddUserMessage(text string) error { +func (h *ChatMessageHistory) AddUserMessage(_ context.Context, text string) error { h.messages = append(h.messages, schema.HumanChatMessage{Content: text}) return nil } @@ -37,12 +41,12 @@ func (h *ChatMessageHistory) Clear() error { return nil } -func (h *ChatMessageHistory) AddMessage(message schema.ChatMessage) error { +func (h *ChatMessageHistory) AddMessage(_ context.Context, message schema.ChatMessage) error { h.messages = append(h.messages, message) return nil } -func (h *ChatMessageHistory) SetMessages(messages []schema.ChatMessage) error { +func (h *ChatMessageHistory) SetMessages(_ context.Context, messages []schema.ChatMessage) error { h.messages = messages return nil } diff --git a/memory/chat_test.go b/memory/chat_test.go index c6e7122d1..077c904f5 100644 --- a/memory/chat_test.go +++ b/memory/chat_test.go @@ -1,9 +1,11 @@ package memory import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/tmc/langchaingo/schema" ) @@ -11,12 +13,12 @@ func TestChatMessageHistory(t *testing.T) { t.Parallel() h := NewChatMessageHistory() - err := h.AddAIMessage("foo") + err := h.AddAIMessage(context.Background(), "foo") assert.NoError(t, err) - err = h.AddUserMessage("bar") + err = h.AddUserMessage(context.Background(), "bar") assert.NoError(t, err) - messages, err := h.Messages() + messages, err := h.Messages(context.Background()) assert.NoError(t, err) assert.Equal(t, []schema.ChatMessage{ @@ -30,10 +32,10 @@ func TestChatMessageHistory(t *testing.T) { schema.SystemChatMessage{Content: "bar"}, }), ) - err = h.AddUserMessage("zoo") + err = h.AddUserMessage(context.Background(), "zoo") assert.NoError(t, err) - messages, err = h.Messages() + messages, err = h.Messages(context.Background()) assert.NoError(t, err) assert.Equal(t, []schema.ChatMessage{ diff --git a/memory/simple.go b/memory/simple.go index 244db7e9e..2cbca2cbf 100644 --- a/memory/simple.go +++ b/memory/simple.go @@ -1,6 +1,8 @@ package memory import ( + "context" + "github.com/tmc/langchaingo/schema" ) @@ -15,15 +17,15 @@ func NewSimple() Simple { // Statically assert that Simple implement the memory interface. var _ schema.Memory = Simple{} -func (m Simple) MemoryVariables() []string { +func (m Simple) MemoryVariables(context.Context) []string { return nil } -func (m Simple) LoadMemoryVariables(map[string]any) (map[string]any, error) { - return make(map[string]any, 0), nil +func (m Simple) LoadMemoryVariables(context.Context, map[string]any) (map[string]any, error) { + return make(map[string]any), nil } -func (m Simple) SaveContext(map[string]any, map[string]any) error { +func (m Simple) SaveContext(context.Context, map[string]any, map[string]any) error { return nil } @@ -31,6 +33,6 @@ func (m Simple) Clear() error { return nil } -func (m Simple) GetMemoryKey() string { +func (m Simple) GetMemoryKey(context.Context) string { return "" } diff --git a/memory/token_buffer.go b/memory/token_buffer.go index 8bb0c5ffe..f8e4b8753 100644 --- a/memory/token_buffer.go +++ b/memory/token_buffer.go @@ -1,6 +1,8 @@ package memory import ( + "context" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" ) @@ -31,22 +33,26 @@ func NewConversationTokenBuffer( } // MemoryVariables uses ConversationBuffer method for memory variables. -func (tb *ConversationTokenBuffer) MemoryVariables() []string { - return tb.ConversationBuffer.MemoryVariables() +func (tb *ConversationTokenBuffer) MemoryVariables(ctx context.Context) []string { + return tb.ConversationBuffer.MemoryVariables(ctx) } // LoadMemoryVariables uses ConversationBuffer method for loading memory variables. -func (tb *ConversationTokenBuffer) LoadMemoryVariables(inputs map[string]any) (map[string]any, error) { - return tb.ConversationBuffer.LoadMemoryVariables(inputs) +func (tb *ConversationTokenBuffer) LoadMemoryVariables( + ctx context.Context, inputs map[string]any, +) (map[string]any, error) { + return tb.ConversationBuffer.LoadMemoryVariables(ctx, inputs) } // SaveContext uses ConversationBuffer method for saving context and prunes memory buffer if needed. -func (tb *ConversationTokenBuffer) SaveContext(inputValues map[string]any, outputValues map[string]any) error { - err := tb.ConversationBuffer.SaveContext(inputValues, outputValues) +func (tb *ConversationTokenBuffer) SaveContext( + ctx context.Context, inputValues map[string]any, outputValues map[string]any, +) error { + err := tb.ConversationBuffer.SaveContext(ctx, inputValues, outputValues) if err != nil { return err } - currBufferLength, err := tb.getNumTokensFromMessages() + currBufferLength, err := tb.getNumTokensFromMessages(ctx) if err != nil { return err } @@ -55,7 +61,7 @@ func (tb *ConversationTokenBuffer) SaveContext(inputValues map[string]any, outpu // while currBufferLength is greater than MaxTokenLimit we keep removing messages from the memory // from the oldest for currBufferLength > tb.MaxTokenLimit { - messages, err := tb.ChatHistory.Messages() + messages, err := tb.ChatHistory.Messages(ctx) if err != nil { return err } @@ -64,12 +70,12 @@ func (tb *ConversationTokenBuffer) SaveContext(inputValues map[string]any, outpu break } - err = tb.ChatHistory.SetMessages(append(messages[:0], messages[1:]...)) + err = tb.ChatHistory.SetMessages(ctx, append(messages[:0], messages[1:]...)) if err != nil { return err } - currBufferLength, err = tb.getNumTokensFromMessages() + currBufferLength, err = tb.getNumTokensFromMessages(ctx) if err != nil { return err } @@ -84,8 +90,8 @@ func (tb *ConversationTokenBuffer) Clear() error { return tb.ConversationBuffer.Clear() } -func (tb *ConversationTokenBuffer) getNumTokensFromMessages() (int, error) { - messages, err := tb.ChatHistory.Messages() +func (tb *ConversationTokenBuffer) getNumTokensFromMessages(ctx context.Context) (int, error) { + messages, err := tb.ChatHistory.Messages(ctx) if err != nil { return 0, err } diff --git a/memory/token_buffer_test.go b/memory/token_buffer_test.go index 2a93f2493..36b2eb15b 100644 --- a/memory/token_buffer_test.go +++ b/memory/token_buffer_test.go @@ -1,6 +1,7 @@ package memory import ( + "context" "os" "testing" @@ -21,15 +22,15 @@ func TestTokenBufferMemory(t *testing.T) { require.NoError(t, err) m := NewConversationTokenBuffer(llm, 2000) - result1, err := m.LoadMemoryVariables(map[string]any{}) + result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected1 := map[string]any{"history": ""} assert.Equal(t, expected1, result1) - err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) + err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) require.NoError(t, err) - result2, err := m.LoadMemoryVariables(map[string]any{}) + result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected2 := map[string]any{"history": "Human: bar\nAI: foo"} @@ -48,14 +49,14 @@ func TestTokenBufferMemoryReturnMessage(t *testing.T) { m := NewConversationTokenBuffer(llm, 2000, WithReturnMessages(true)) expected1 := map[string]any{"history": []schema.ChatMessage{}} - result1, err := m.LoadMemoryVariables(map[string]any{}) + result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) assert.Equal(t, expected1, result1) - err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) + err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"}) require.NoError(t, err) - result2, err := m.LoadMemoryVariables(map[string]any{}) + result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expectedChatHistory := NewChatMessageHistory( @@ -65,7 +66,7 @@ func TestTokenBufferMemoryReturnMessage(t *testing.T) { }), ) - messages, err := expectedChatHistory.Messages() + messages, err := expectedChatHistory.Messages(context.Background()) require.NoError(t, err) expected2 := map[string]any{"history": messages} assert.Equal(t, expected2, result2) @@ -88,7 +89,7 @@ func TestTokenBufferMemoryWithPreLoadedHistory(t *testing.T) { }), ))) - result, err := m.LoadMemoryVariables(map[string]any{}) + result, err := m.LoadMemoryVariables(context.Background(), map[string]any{}) require.NoError(t, err) expected := map[string]any{"history": "Human: bar\nAI: foo"} assert.Equal(t, expected, result) diff --git a/schema/chat_message_history.go b/schema/chat_message_history.go index 2c573e9f9..1a30fb004 100644 --- a/schema/chat_message_history.go +++ b/schema/chat_message_history.go @@ -1,22 +1,24 @@ package schema +import "context" + // ChatMessageHistory is the interface for chat history in memory/store. type ChatMessageHistory interface { // AddUserMessage Convenience method for adding a human message string to the store. - AddUserMessage(message string) error + AddUserMessage(ctx context.Context, message string) error // AddAIMessage Convenience method for adding an AI message string to the store. - AddAIMessage(message string) error + AddAIMessage(ctx context.Context, message string) error // AddMessage Add a Message object to the store. - AddMessage(message ChatMessage) error + AddMessage(ctx context.Context, message ChatMessage) error // Clear Remove all messages from the store. Clear() error // Messages get all messages from the store - Messages() ([]ChatMessage, error) + Messages(ctx context.Context) ([]ChatMessage, error) // SetMessages replaces existing messages in the store - SetMessages(messages []ChatMessage) error + SetMessages(ctx context.Context, messages []ChatMessage) error } diff --git a/schema/memory.go b/schema/memory.go index c56221636..8579d2f0f 100644 --- a/schema/memory.go +++ b/schema/memory.go @@ -1,16 +1,18 @@ package schema +import "context" + // Memory is the interface for memory in chains. type Memory interface { // GetMemoryKey getter for memory key. - GetMemoryKey() string + GetMemoryKey(ctx context.Context) string // MemoryVariables Input keys this memory class will load dynamically. - MemoryVariables() []string + MemoryVariables(ctx context.Context) []string // LoadMemoryVariables Return key-value pairs given the text input to the chain. // If None, return all memories - LoadMemoryVariables(inputs map[string]any) (map[string]any, error) + LoadMemoryVariables(ctx context.Context, inputs map[string]any) (map[string]any, error) // SaveContext Save the context of this model run to memory. - SaveContext(inputs map[string]any, outputs map[string]any) error + SaveContext(ctx context.Context, inputs map[string]any, outputs map[string]any) error // Clear memory contents. Clear() error }