Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callbacks: add standard interface and logger #257

Merged
merged 12 commits into from
Aug 26, 2023
63 changes: 44 additions & 19 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
Expand All @@ -14,15 +15,19 @@ const _intermediateStepsOutputKey = "intermediateSteps"

// Executor is the chain responsible for running agents.
type Executor struct {
Agent Agent
Tools []tools.Tool
Memory schema.Memory
Agent Agent
Tools []tools.Tool
Memory schema.Memory
CallbacksHandler callbacks.Handler

MaxIterations int
ReturnIntermediateSteps bool
}

var _ chains.Chain = Executor{}
var (
_ chains.Chain = Executor{}
_ callbacks.HandlerHaver = Executor{}
)

// NewExecutor creates a new agent executor with a agent and the tools the agent can use.
func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Executor {
Expand All @@ -37,6 +42,7 @@ func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Execut
Memory: options.memory,
MaxIterations: options.maxIterations,
ReturnIntermediateSteps: options.returnIntermediateSteps,
CallbacksHandler: options.callbacksHandler,
}
}

Expand All @@ -63,30 +69,45 @@ func (e Executor) Call(ctx context.Context, inputValues map[string]any, _ ...cha
}

for _, action := range actions {
tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
steps = append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
})
continue
}

observation, err := tool.Call(ctx, action.ToolInput)
steps, err = e.doAction(ctx, steps, nameToTool, action)
if err != nil {
return nil, err
}

steps = append(steps, schema.AgentStep{
Action: action,
Observation: observation,
})
}
}

return nil, ErrNotFinished
}

func (e Executor) doAction(
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
action schema.AgentAction,
) ([]schema.AgentStep, error) {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}

tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
return append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
}), nil
}

observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
return nil, err
}

return append(steps, schema.AgentStep{
Action: action,
Observation: observation,
}), nil
}

func (e Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
if e.ReturnIntermediateSteps {
finish.ReturnValues[_intermediateStepsOutputKey] = steps
Expand All @@ -110,6 +131,10 @@ func (e Executor) GetMemory() schema.Memory { //nolint:ireturn
return e.Memory
}

func (e Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return e.CallbacksHandler
}

func inputsToString(inputValues map[string]any) (map[string]string, error) {
inputs := make(map[string]string, len(inputValues))
for key, value := range inputValues {
Expand Down
8 changes: 8 additions & 0 deletions agents/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agents

import (
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand All @@ -10,6 +11,7 @@ import (
type CreationOptions struct {
prompt prompts.PromptTemplate
memory schema.Memory
callbacksHandler callbacks.Handler
maxIterations int
returnIntermediateSteps bool
outputKey string
Expand Down Expand Up @@ -131,3 +133,9 @@ func WithMemory(m schema.Memory) CreationOption {
co.memory = m
}
}

func WithCallbacksHandler(handler callbacks.Handler) CreationOption {
return func(co *CreationOptions) {
co.callbacksHandler = handler
}
}
28 changes: 28 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package callbacks

import (
"context"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

// Handler is the interface that allows for hooking into specific parts of an
// LLM application.
type Handler interface {
HandleText(ctx context.Context, text string)
HandleLLMStart(ctx context.Context, prompts []string)
HandleLLMEnd(ctx context.Context, output llms.LLMResult)
HandleChainStart(ctx context.Context, inputs map[string]any)
HandleChainEnd(ctx context.Context, outputs map[string]any)
HandleToolStart(ctx context.Context, input string)
HandleToolEnd(ctx context.Context, output string)
HandleAgentAction(ctx context.Context, action schema.AgentAction)
HandleRetrieverStart(ctx context.Context, query string)
HandleRetrieverEnd(ctx context.Context, documents []schema.Document)
}

// HandlerHaver is an interface used to get callbacks handler.
type HandlerHaver interface {
GetCallbackHandler() Handler
}
4 changes: 4 additions & 0 deletions callbacks/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Package callbacks includes a standard interface for hooking into various
// stages of your LLM application. The package contains an implementation of
// this interface that prints to the standard output.
package callbacks
84 changes: 84 additions & 0 deletions callbacks/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//nolint:forbidigo
package callbacks

import (
"context"
"fmt"
"strings"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

// LogHandler is a callback handler that prints to the standard output.
type LogHandler struct{}

var _ Handler = LogHandler{}

func (l LogHandler) HandleText(_ context.Context, text string) {
fmt.Println(text)
}

func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) {
fmt.Println("Entering LLM with prompts:", prompts)
}

func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) {
fmt.Println("Exiting LLM with results:", formatLLMResult(output))
}

func (l LogHandler) HandleChainStart(_ context.Context, inputs map[string]any) {
fmt.Println("Entering chain with inputs:", formatChainValues(inputs))
}

func (l LogHandler) HandleChainEnd(_ context.Context, outputs map[string]any) {
fmt.Println("Exiting chain with outputs:", formatChainValues(outputs))
}

func (l LogHandler) HandleToolStart(_ context.Context, input string) {
fmt.Println("Entering tool with input:", removeNewLines(input))
}

func (l LogHandler) HandleToolEnd(_ context.Context, output string) {
fmt.Println("Exiting tool with output:", removeNewLines(output))
}

func (l LogHandler) HandleAgentAction(_ context.Context, action schema.AgentAction) {
fmt.Println("Agent selected action:", formatAgentAction(action))
}

func (l LogHandler) HandleRetrieverStart(_ context.Context, query string) {
fmt.Println("Entering retriever with query:", removeNewLines(query))
}

func (l LogHandler) HandleRetrieverEnd(_ context.Context, documents []schema.Document) {
fmt.Println("Exiting retirer with documents:", documents)
}

func formatChainValues(values map[string]any) string {
output := ""
for key, value := range values {
output += fmt.Sprintf("\"%s\" : \"%s\", ", removeNewLines(key), removeNewLines(value))
}

return output
}

func formatLLMResult(output llms.LLMResult) string {
results := "[ "
for i := 0; i < len(output.Generations); i++ {
for j := 0; j < len(output.Generations[i]); j++ {
results += output.Generations[i][j].Text
}
}

return results + " ]"
}

func formatAgentAction(action schema.AgentAction) string {
return fmt.Sprintf("\"%s\" with input \"%s\"", removeNewLines(action.Tool), removeNewLines(action.ToolInput))
}

func removeNewLines(s any) string {
return strings.ReplaceAll(fmt.Sprint(s), "\n", " ")
}
17 changes: 17 additions & 0 deletions chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sync"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/schema"
)

Expand Down Expand Up @@ -41,6 +42,11 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
fullValues[key] = value
}

callbacksHandler := getChainCallbackHandler(c)
if callbacksHandler != nil {
callbacksHandler.HandleChainStart(ctx, inputValues)
}

if err := validateInputs(c, fullValues); err != nil {
return nil, err
}
Expand All @@ -53,6 +59,10 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
return nil, err
}

if callbacksHandler != nil {
callbacksHandler.HandleChainEnd(ctx, outputValues)
}

err = c.GetMemory().SaveContext(ctx, inputValues, outputValues)
if err != nil {
return nil, err
Expand Down Expand Up @@ -225,3 +235,10 @@ func validateOutputs(c Chain, outputValues map[string]any) error {
}
return nil
}

func getChainCallbackHandler(c Chain) callbacks.Handler {
if handlerHaver, ok := c.(callbacks.HandlerHaver); ok {
return handlerHaver.GetCallbackHandler()
}
return nil
}
19 changes: 14 additions & 5 deletions chains/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chains
import (
"context"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/outputparser"
Expand All @@ -13,15 +14,19 @@ import (
const _llmChainDefaultOutputKey = "text"

type LLMChain struct {
Prompt prompts.FormatPrompter
LLM llms.LanguageModel
Memory schema.Memory
OutputParser schema.OutputParser[any]
Prompt prompts.FormatPrompter
LLM llms.LanguageModel
Memory schema.Memory
CallbacksHandler callbacks.Handler
OutputParser schema.OutputParser[any]

OutputKey string
}

var _ Chain = &LLMChain{}
var (
_ Chain = &LLMChain{}
_ callbacks.HandlerHaver = &LLMChain{}
)

// NewLLMChain creates a new LLMChain with an llm and a prompt.
func NewLLMChain(llm llms.LanguageModel, prompt prompts.FormatPrompter) *LLMChain {
Expand Down Expand Up @@ -69,6 +74,10 @@ func (c LLMChain) GetMemory() schema.Memory { //nolint:ireturn
return c.Memory //nolint:ireturn
}

func (c LLMChain) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return c.CallbacksHandler
}

// GetInputKeys returns the expected input keys.
func (c LLMChain) GetInputKeys() []string {
return append([]string{}, c.Prompt.GetInputVariables()...)
Expand Down
2 changes: 2 additions & 0 deletions chains/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
)
Expand All @@ -18,6 +19,7 @@ func TestLLMChain(t *testing.T) {
}
model, err := openai.New()
require.NoError(t, err)
model.CallbacksHandler = callbacks.LogHandler{}

prompt := prompts.NewPromptTemplate(
"What is the capital of {{.country}}",
Expand Down
2 changes: 1 addition & 1 deletion examples/anthropic-completion-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module anthropic-completion-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849
require github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/anthropic-completion-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849 h1:y4TnpS57FeE5QzBzV2wKysVxWvkuMsha12yGHPVVAFo=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0 h1:7pmW0coaYnLm4evqJ+QR10EgDU1ku+xGOpcHBMuvE7E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0/go.mod h1:fd7jP67Fwvcr+i7J+oAZbrh2aiekUuBLVqW/vgDslnw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/cohere-llm-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module basic-llm-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849
require github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0

require (
github.com/cohere-ai/tokenizer v1.1.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/cohere-llm-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849 h1:y4TnpS57FeE5QzBzV2wKysVxWvkuMsha12yGHPVVAFo=
github.com/tmc/langchaingo v0.0.0-20230729231952-1f3948210849/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0 h1:7pmW0coaYnLm4evqJ+QR10EgDU1ku+xGOpcHBMuvE7E=
github.com/tmc/langchaingo v0.0.0-20230826015154-aa97aec400c0/go.mod h1:fd7jP67Fwvcr+i7J+oAZbrh2aiekUuBLVqW/vgDslnw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Loading
Loading