From c6b8f4f312e69537927a19230610a7cb6a1009d2 Mon Sep 17 00:00:00 2001
From: Juuso Lindholm
Date: Sun, 16 Jun 2024 23:34:11 +0200
Subject: [PATCH] ollama: Fix JSON format bug when not streaming (#892)

* Gracefully handle whitespace emitted by the LLM in JSON mode with Ollama.

* ollama: Simplify stream representation, spruce up the function-calling example

---------

Co-authored-by: Travis Cline
---
 .../ollama_functions_example.go            | 33 ++++++++++---------
 llms/ollama/internal/ollamaclient/types.go |  2 +-
 llms/ollama/ollamallm.go                   |  4 +--
 3 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/examples/ollama-functions-example/ollama_functions_example.go b/examples/ollama-functions-example/ollama_functions_example.go
index 4702b44ad..26f1e30d9 100644
--- a/examples/ollama-functions-example/ollama_functions_example.go
+++ b/examples/ollama-functions-example/ollama_functions_example.go
@@ -3,6 +3,7 @@ package main
 import (
 	"context"
 	"encoding/json"
+	"flag"
 	"fmt"
 	"log"
 	"os"
@@ -12,10 +13,13 @@ import (
 	"github.com/tmc/langchaingo/llms/ollama"
 )
 
+var flagVerbose = flag.Bool("v", false, "verbose mode")
+
 func main() {
+	flag.Parse()
 	// allow specifying your own model via OLLAMA_TEST_MODEL
 	// (same as the Ollama unit tests).
-	model := "mistral:instruct"
+	model := "llama3"
 	if v := os.Getenv("OLLAMA_TEST_MODEL"); v != "" {
 		model = v
 	}
@@ -31,14 +35,12 @@ func main() {
 	var msgs []llms.MessageContent
 	// system message defines the available tools.
-	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem,
-		systemMessage()))
-	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman,
-		"What's the weather like in Beijing?"))
+	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem, systemMessage()))
+	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "What's the weather like in Beijing?"))
 
 	ctx := context.Background()
 
-	for {
+	for retries := 3; retries > 0; retries = retries - 1 {
 		resp, err := llm.GenerateContent(ctx, msgs)
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -49,19 +51,23 @@ func main() {
 
 		if c := unmarshalCall(choice1.Content); c != nil {
 			log.Printf("Call: %v", c.Tool)
-
+			if *flagVerbose {
+				log.Printf("Call: %v (raw: %v)", c.Tool, choice1.Content)
+			}
 			msg, cont := dispatchCall(c)
 			if !cont {
 				break
 			}
-
 			msgs = append(msgs, msg)
 		} else {
 			// Ollama doesn't always respond with a function call, let it try again.
 			log.Printf("Not a call: %v", choice1.Content)
-
 			msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "Sorry, I don't understand. Please try again."))
 		}
+
+		if retries == 1 {
+			log.Fatal("retries exhausted")
+		}
 	}
 }
 
@@ -72,11 +78,9 @@ type Call struct {
 
 func unmarshalCall(input string) *Call {
 	var c Call
-
 	if err := json.Unmarshal([]byte(input), &c); err == nil && c.Tool != "" {
 		return &c
 	}
-
 	return nil
 }
 
@@ -84,8 +88,7 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
 	// ollama doesn't always respond with a *valid* function call. As we're using prompt
 	// engineering to inject the tools, it may hallucinate.
 	if !validTool(c.Tool) {
-		log.Printf("invalid function call: %#v", c)
-
+		log.Printf("invalid function call: %#v, prompting model to try again", c)
 		return llms.TextParts(llms.ChatMessageTypeHuman, "Tool does not exist, please try again."), true
 	}
 
@@ -106,7 +109,7 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
 		if err != nil {
 			log.Fatal(err)
 		}
-		return llms.TextParts(llms.ChatMessageTypeSystem, weather), true
+		return llms.TextParts(llms.ChatMessageTypeHuman, weather), true
 	case "finalResponse":
 		resp, ok := c.Input["response"].(string)
 		if !ok {
@@ -124,11 +127,9 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
 
 func validTool(name string) bool {
 	var valid []string
-
 	for _, v := range functions {
 		valid = append(valid, v.Name)
 	}
-
 	return slices.Contains(valid, name)
 }
 
diff --git a/llms/ollama/internal/ollamaclient/types.go b/llms/ollama/internal/ollamaclient/types.go
index 87d1ecb7d..10a7fb19d 100644
--- a/llms/ollama/internal/ollamaclient/types.go
+++ b/llms/ollama/internal/ollamaclient/types.go
@@ -49,7 +49,7 @@ type Message struct {
 
 type ChatRequest struct {
 	Model     string     `json:"model"`
 	Messages  []*Message `json:"messages"`
-	Stream    *bool      `json:"stream,omitempty"`
+	Stream    bool       `json:"stream"`
 	Format    string     `json:"format"`
 	KeepAlive string     `json:"keep_alive,omitempty"`
diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go
index 71ea7397f..0de34a599 100644
--- a/llms/ollama/ollamallm.go
+++ b/llms/ollama/ollamallm.go
@@ -108,7 +108,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		Format:   format,
 		Messages: chatMsgs,
 		Options:  ollamaOptions,
-		Stream:   func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
+		Stream:   opts.StreamingFunc != nil,
 	}
 
 	keepAlive := o.options.keepAlive
@@ -129,7 +129,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		if response.Message != nil {
 			streamedResponse += response.Message.Content
 		}
-		if response.Done {
+		if !req.Stream || response.Done {
 			resp = response
 			resp.Message = &ollamaclient.Message{
 				Role: "assistant",
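
For reference, a minimal sketch (not part of the diff) of the non-streaming path this patch fixes. The model name and prompt are placeholders, and ollama.WithFormat / llms.WithStreamingFunc are assumed option names from the public langchaingo API; the other calls appear in the example above.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

func main() {
	// Model name is a placeholder; any locally pulled Ollama model works.
	// ollama.WithFormat("json") is assumed to set the request's Format field.
	llm, err := ollama.New(ollama.WithModel("llama3"), ollama.WithFormat("json"))
	if err != nil {
		log.Fatal(err)
	}

	msgs := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, `Answer with a JSON object like {"capital": "..."}: what is the capital of France?`),
	}

	// No llms.WithStreamingFunc option is passed, so with this patch the
	// request is sent with "stream": false and the complete JSON message is
	// returned in a single response rather than assembled from partial chunks.
	resp, err := llm.GenerateContent(context.Background(), msgs)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Content)
}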