Skip to content

Commit

Permalink
Merge pull request #257 from FluffyKebab/callbacks
Browse files Browse the repository at this point in the history
callbacks: add standard interface and logger
  • Loading branch information
tmc authored Aug 26, 2023
2 parents caa5453 + c9460de commit 9f51407
Show file tree
Hide file tree
Showing 62 changed files with 489 additions and 122 deletions.
63 changes: 44 additions & 19 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
Expand All @@ -14,15 +15,19 @@ const _intermediateStepsOutputKey = "intermediateSteps"

// Executor is the chain responsible for running agents.
type Executor struct {
Agent Agent
Tools []tools.Tool
Memory schema.Memory
Agent Agent
Tools []tools.Tool
Memory schema.Memory
CallbacksHandler callbacks.Handler

MaxIterations int
ReturnIntermediateSteps bool
}

var _ chains.Chain = Executor{}
var (
_ chains.Chain = Executor{}
_ callbacks.HandlerHaver = Executor{}
)

// NewExecutor creates a new agent executor with a agent and the tools the agent can use.
func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Executor {
Expand All @@ -37,6 +42,7 @@ func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Execut
Memory: options.memory,
MaxIterations: options.maxIterations,
ReturnIntermediateSteps: options.returnIntermediateSteps,
CallbacksHandler: options.callbacksHandler,
}
}

Expand All @@ -63,30 +69,45 @@ func (e Executor) Call(ctx context.Context, inputValues map[string]any, _ ...cha
}

for _, action := range actions {
tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
steps = append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
})
continue
}

observation, err := tool.Call(ctx, action.ToolInput)
steps, err = e.doAction(ctx, steps, nameToTool, action)
if err != nil {
return nil, err
}

steps = append(steps, schema.AgentStep{
Action: action,
Observation: observation,
})
}
}

return nil, ErrNotFinished
}

func (e Executor) doAction(
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
action schema.AgentAction,
) ([]schema.AgentStep, error) {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}

tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
return append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
}), nil
}

observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
return nil, err
}

return append(steps, schema.AgentStep{
Action: action,
Observation: observation,
}), nil
}

func (e Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
if e.ReturnIntermediateSteps {
finish.ReturnValues[_intermediateStepsOutputKey] = steps
Expand All @@ -110,6 +131,10 @@ func (e Executor) GetMemory() schema.Memory { //nolint:ireturn
return e.Memory
}

func (e Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return e.CallbacksHandler
}

func inputsToString(inputValues map[string]any) (map[string]string, error) {
inputs := make(map[string]string, len(inputValues))
for key, value := range inputValues {
Expand Down
8 changes: 8 additions & 0 deletions agents/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agents

import (
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand All @@ -10,6 +11,7 @@ import (
type CreationOptions struct {
prompt prompts.PromptTemplate
memory schema.Memory
callbacksHandler callbacks.Handler
maxIterations int
returnIntermediateSteps bool
outputKey string
Expand Down Expand Up @@ -131,3 +133,9 @@ func WithMemory(m schema.Memory) CreationOption {
co.memory = m
}
}

func WithCallbacksHandler(handler callbacks.Handler) CreationOption {
return func(co *CreationOptions) {
co.callbacksHandler = handler
}
}
28 changes: 28 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package callbacks

import (
"context"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

// Handler is the interface that allows for hooking into specific parts of an
// LLM application.
type Handler interface {
HandleText(ctx context.Context, text string)
HandleLLMStart(ctx context.Context, prompts []string)
HandleLLMEnd(ctx context.Context, output llms.LLMResult)
HandleChainStart(ctx context.Context, inputs map[string]any)
HandleChainEnd(ctx context.Context, outputs map[string]any)
HandleToolStart(ctx context.Context, input string)
HandleToolEnd(ctx context.Context, output string)
HandleAgentAction(ctx context.Context, action schema.AgentAction)
HandleRetrieverStart(ctx context.Context, query string)
HandleRetrieverEnd(ctx context.Context, documents []schema.Document)
}

// HandlerHaver is an interface used to get callbacks handler.
type HandlerHaver interface {
GetCallbackHandler() Handler
}
4 changes: 4 additions & 0 deletions callbacks/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Package callbacks includes a standard interface for hooking into various
// stages of your LLM application. The package contains an implementation of
// this interface that prints to the standard output.
package callbacks
84 changes: 84 additions & 0 deletions callbacks/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//nolint:forbidigo
package callbacks

import (
"context"
"fmt"
"strings"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

// LogHandler is a callback handler that prints to the standard output.
type LogHandler struct{}

var _ Handler = LogHandler{}

func (l LogHandler) HandleText(_ context.Context, text string) {
fmt.Println(text)
}

func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) {
fmt.Println("Entering LLM with prompts:", prompts)
}

func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) {
fmt.Println("Exiting LLM with results:", formatLLMResult(output))
}

func (l LogHandler) HandleChainStart(_ context.Context, inputs map[string]any) {
fmt.Println("Entering chain with inputs:", formatChainValues(inputs))
}

func (l LogHandler) HandleChainEnd(_ context.Context, outputs map[string]any) {
fmt.Println("Exiting chain with outputs:", formatChainValues(outputs))
}

func (l LogHandler) HandleToolStart(_ context.Context, input string) {
fmt.Println("Entering tool with input:", removeNewLines(input))
}

func (l LogHandler) HandleToolEnd(_ context.Context, output string) {
fmt.Println("Exiting tool with output:", removeNewLines(output))
}

func (l LogHandler) HandleAgentAction(_ context.Context, action schema.AgentAction) {
fmt.Println("Agent selected action:", formatAgentAction(action))
}

func (l LogHandler) HandleRetrieverStart(_ context.Context, query string) {
fmt.Println("Entering retriever with query:", removeNewLines(query))
}

func (l LogHandler) HandleRetrieverEnd(_ context.Context, documents []schema.Document) {
fmt.Println("Exiting retirer with documents:", documents)
}

func formatChainValues(values map[string]any) string {
output := ""
for key, value := range values {
output += fmt.Sprintf("\"%s\" : \"%s\", ", removeNewLines(key), removeNewLines(value))
}

return output
}

func formatLLMResult(output llms.LLMResult) string {
results := "[ "
for i := 0; i < len(output.Generations); i++ {
for j := 0; j < len(output.Generations[i]); j++ {
results += output.Generations[i][j].Text
}
}

return results + " ]"
}

func formatAgentAction(action schema.AgentAction) string {
return fmt.Sprintf("\"%s\" with input \"%s\"", removeNewLines(action.Tool), removeNewLines(action.ToolInput))
}

func removeNewLines(s any) string {
return strings.ReplaceAll(fmt.Sprint(s), "\n", " ")
}
17 changes: 17 additions & 0 deletions chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sync"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/schema"
)

Expand Down Expand Up @@ -41,6 +42,11 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
fullValues[key] = value
}

callbacksHandler := getChainCallbackHandler(c)
if callbacksHandler != nil {
callbacksHandler.HandleChainStart(ctx, inputValues)
}

if err := validateInputs(c, fullValues); err != nil {
return nil, err
}
Expand All @@ -53,6 +59,10 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
return nil, err
}

if callbacksHandler != nil {
callbacksHandler.HandleChainEnd(ctx, outputValues)
}

err = c.GetMemory().SaveContext(ctx, inputValues, outputValues)
if err != nil {
return nil, err
Expand Down Expand Up @@ -225,3 +235,10 @@ func validateOutputs(c Chain, outputValues map[string]any) error {
}
return nil
}

func getChainCallbackHandler(c Chain) callbacks.Handler {
if handlerHaver, ok := c.(callbacks.HandlerHaver); ok {
return handlerHaver.GetCallbackHandler()
}
return nil
}
19 changes: 14 additions & 5 deletions chains/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chains
import (
"context"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/outputparser"
Expand All @@ -13,15 +14,19 @@ import (
const _llmChainDefaultOutputKey = "text"

type LLMChain struct {
Prompt prompts.FormatPrompter
LLM llms.LanguageModel
Memory schema.Memory
OutputParser schema.OutputParser[any]
Prompt prompts.FormatPrompter
LLM llms.LanguageModel
Memory schema.Memory
CallbacksHandler callbacks.Handler
OutputParser schema.OutputParser[any]

OutputKey string
}

var _ Chain = &LLMChain{}
var (
_ Chain = &LLMChain{}
_ callbacks.HandlerHaver = &LLMChain{}
)

// NewLLMChain creates a new LLMChain with an llm and a prompt.
func NewLLMChain(llm llms.LanguageModel, prompt prompts.FormatPrompter) *LLMChain {
Expand Down Expand Up @@ -69,6 +74,10 @@ func (c LLMChain) GetMemory() schema.Memory { //nolint:ireturn
return c.Memory //nolint:ireturn
}

func (c LLMChain) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return c.CallbacksHandler
}

// GetInputKeys returns the expected input keys.
func (c LLMChain) GetInputKeys() []string {
return append([]string{}, c.Prompt.GetInputVariables()...)
Expand Down
2 changes: 2 additions & 0 deletions chains/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
)
Expand All @@ -18,6 +19,7 @@ func TestLLMChain(t *testing.T) {
}
model, err := openai.New()
require.NoError(t, err)
model.CallbacksHandler = callbacks.LogHandler{}

prompt := prompts.NewPromptTemplate(
"What is the capital of {{.country}}",
Expand Down
2 changes: 1 addition & 1 deletion examples/anthropic-completion-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module anthropic-completion-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849
require github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/anthropic-completion-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849 h1:y4TnpS57FeE5QzBzV2wKysVxWvkuMsha12yGHPVVAFo=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0 h1:7pmW0coaYnLm4evqJ+QR10EgDU1ku+xGOpcHBMuvE7E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0/go.mod h1:fd7jP67Fwvcr+i7J+oAZbrh2aiekUuBLVqW/vgDslnw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/cohere-llm-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module basic-llm-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849
require github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0

require (
github.com/cohere-ai/tokenizer v1.1.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/cohere-llm-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849 h1:y4TnpS57FeE5QzBzV2wKysVxWvkuMsha12yGHPVVAFo=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0 h1:7pmW0coaYnLm4evqJ+QR10EgDU1ku+xGOpcHBMuvE7E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0/go.mod h1:fd7jP67Fwvcr+i7J+oAZbrh2aiekUuBLVqW/vgDslnw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Loading

0 comments on commit 9f51407

Please sign in to comment.