diff --git a/embeddings/vertexai/vertexai_palm_test.go b/embeddings/vertexai/vertexai_palm_test.go
index 882522c28..dc3a39584 100644
--- a/embeddings/vertexai/vertexai_palm_test.go
+++ b/embeddings/vertexai/vertexai_palm_test.go
@@ -37,7 +37,7 @@ func TestVertexAIPaLMEmbeddingsWithOptions(t *testing.T) {
 	client, err := vertexai.New()
 	require.NoError(t, err)
 
-	e, err := NewVertexAIPaLM(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
+	e, err := NewVertexAIPaLM(WithClient(*client), WithBatchSize(5), WithStripNewLines(false))
 	require.NoError(t, err)
 
 	_, err = e.EmbedQuery(context.Background(), "Hello world!")
diff --git a/embeddings/vertexai/vertexaichat/options.go b/embeddings/vertexai/vertexaichat/options.go
new file mode 100644
index 000000000..c49860a77
--- /dev/null
+++ b/embeddings/vertexai/vertexaichat/options.go
@@ -0,0 +1,54 @@
+package vertexaichat
+
+import (
+	"github.com/tmc/langchaingo/llms/vertexai"
+)
+
+const (
+	_defaultBatchSize     = 512
+	_defaultStripNewLines = true
+)
+
+type ChatOption func(p *ChatVertexAI)
+
+// WithClient is an option for providing the LLM client.
+func WithClient(client vertexai.Chat) ChatOption {
+	return func(p *ChatVertexAI) {
+		p.client = &client
+	}
+}
+
+// WithStripNewLines is an option for specifying whether to strip new lines.
+func WithStripNewLines(stripNewLines bool) ChatOption {
+	return func(p *ChatVertexAI) {
+		p.StripNewLines = stripNewLines
+	}
+}
+
+// WithBatchSize is an option for specifying the batch size.
+func WithBatchSize(batchSize int) ChatOption {
+	return func(p *ChatVertexAI) {
+		p.BatchSize = batchSize
+	}
+}
+
+func applyChatClientOptions(opts ...ChatOption) (ChatVertexAI, error) {
+	o := &ChatVertexAI{
+		StripNewLines: _defaultStripNewLines,
+		BatchSize:     _defaultBatchSize,
+	}
+
+	for _, opt := range opts {
+		opt(o)
+	}
+
+	if o.client == nil {
+		client, err := vertexai.NewChat()
+		if err != nil {
+			return ChatVertexAI{}, err
+		}
+		o.client = client
+	}
+
+	return *o, nil
+}
diff --git a/embeddings/vertexai/vertexaichat/vertexai_chat.go b/embeddings/vertexai/vertexaichat/vertexai_chat.go
new file mode 100644
index 000000000..76b24f493
--- /dev/null
+++ b/embeddings/vertexai/vertexaichat/vertexai_chat.go
@@ -0,0 +1,71 @@
+package vertexaichat
+
+import (
+	"context"
+	"strings"
+
+	"github.com/tmc/langchaingo/embeddings"
+	"github.com/tmc/langchaingo/llms/vertexai"
+)
+
+// ChatVertexAI is the embedder using the VertexAI API.
+type ChatVertexAI struct {
+	client *vertexai.Chat
+
+	StripNewLines bool
+	BatchSize     int
+}
+
+var _ embeddings.Embedder = ChatVertexAI{}
+
+// NewChatVertexAI creates a new ChatVertexAI, with options for the client, new line stripping, and batch size.
+func NewChatVertexAI(opts ...ChatOption) (ChatVertexAI, error) {
+	o, err := applyChatClientOptions(opts...)
+	if err != nil {
+		return ChatVertexAI{}, err
+	}
+
+	return o, nil
+}
+
+func (e ChatVertexAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
+	batchedTexts := embeddings.BatchTexts(
+		embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
+		e.BatchSize,
+	)
+
+	emb := make([][]float64, 0, len(texts))
+	for _, texts := range batchedTexts {
+		curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
+		if err != nil {
+			return nil, err
+		}
+
+		textLengths := make([]int, 0, len(texts))
+		for _, text := range texts {
+			textLengths = append(textLengths, len(text))
+		}
+
+		combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
+		if err != nil {
+			return nil, err
+		}
+
+		emb = append(emb, combined)
+	}
+
+	return emb, nil
+}
+
+func (e ChatVertexAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
+	if e.StripNewLines {
+		text = strings.ReplaceAll(text, "\n", " ")
+	}
+
+	emb, err := e.client.CreateEmbedding(ctx, []string{text})
+	if err != nil {
+		return nil, err
+	}
+
+	return emb[0], nil
+}
diff --git a/embeddings/vertexai/vertexaichat/vertexai_chat_test.go b/embeddings/vertexai/vertexaichat/vertexai_chat_test.go
new file mode 100644
index 000000000..3cbb671c6
--- /dev/null
+++ b/embeddings/vertexai/vertexaichat/vertexai_chat_test.go
@@ -0,0 +1,50 @@
+package vertexaichat
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/tmc/langchaingo/llms/vertexai"
+)
+
+func TestVertexAIChatEmbeddings(t *testing.T) {
+	t.Parallel()
+
+	if gcpProjectID := os.Getenv("GOOGLE_CLOUD_PROJECT"); gcpProjectID == "" {
+		t.Skip("GOOGLE_CLOUD_PROJECT not set")
+	}
+
+	e, err := NewChatVertexAI()
+	require.NoError(t, err)
+
+	_, err = e.EmbedQuery(context.Background(), "Hello world!")
+	require.NoError(t, err)
+
+	embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world", "The world is ending", "good bye"})
+	require.NoError(t, err)
+	assert.Len(t, embeddings, 3)
+}
+
+func TestVertexAIChatEmbeddingsWithOptions(t *testing.T) {
+	t.Parallel()
+
+	if gcpProjectID := os.Getenv("GOOGLE_CLOUD_PROJECT"); gcpProjectID == "" {
+		t.Skip("GOOGLE_CLOUD_PROJECT not set")
+	}
+
+	client, err := vertexai.NewChat()
+	require.NoError(t, err)
+
+	e, err := NewChatVertexAI(WithClient(*client), WithBatchSize(5), WithStripNewLines(false))
+	require.NoError(t, err)
+
+	_, err = e.EmbedQuery(context.Background(), "Hello world!")
+	require.NoError(t, err)
+
+	embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world"})
+	require.NoError(t, err)
+	assert.Len(t, embeddings, 1)
+}
diff --git a/llms/vertexai/vertexai_palm_llm.go b/llms/vertexai/vertexai_palm_llm.go
index ea7709e98..e837398e2 100644
--- a/llms/vertexai/vertexai_palm_llm.go
+++ b/llms/vertexai/vertexai_palm_llm.go
@@ -16,11 +16,6 @@ var (
 	ErrNotImplemented = errors.New("not implemented")
 )
 
-const (
-	userAuthor = "user"
-	botAuthor  = "bot"
-)
-
 type LLM struct {
 	client *vertexaiclient.PaLMClient
 }
@@ -92,110 +87,12 @@ func (o *LLM) GetNumTokens(text string) int {
 	return llms.CountTokens(vertexaiclient.TextModelName, text)
 }
 
-type ChatMessage = vertexaiclient.ChatMessage
-
-type Chat struct {
-	client *vertexaiclient.PaLMClient
-}
-
-var (
-	_ llms.ChatLLM       = (*Chat)(nil)
-	_ llms.LanguageModel = (*Chat)(nil)
-)
-
-// Chat requests a chat response for the given messages.
-func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
-	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
-	if err != nil {
-		return nil, err
-	}
-	if len(r) == 0 {
-		return nil, ErrEmptyResponse
-	}
-	return r[0].Message, nil
-}
-
-// Generate requests a chat response for each of the sets of messages.
-func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll
-	opts := llms.CallOptions{}
-	for _, opt := range options {
-		opt(&opts)
-	}
-	if opts.StreamingFunc != nil {
-		return nil, ErrNotImplemented
-	}
-
-	generations := make([]*llms.Generation, 0, len(messageSets))
-	for _, messages := range messageSets {
-		msgs := toClientChatMessage(messages)
-		result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{
-			Temperature: opts.Temperature,
-			Messages:    msgs,
-		})
-		if err != nil {
-			return nil, err
-		}
-		if len(result.Candidates) == 0 {
-			return nil, ErrEmptyResponse
-		}
-		generations = append(generations, &llms.Generation{
-			Message: &schema.AIChatMessage{
-				Content: result.Candidates[0].Content,
-			},
-			Text: result.Candidates[0].Content,
-		})
-	}
-
-	return generations, nil
-}
-
-func (o *Chat) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
-	return llms.GenerateChatPrompt(ctx, o, promptValues, options...)
-}
-
-func (o *Chat) GetNumTokens(text string) int {
-	return llms.CountTokens(vertexaiclient.TextModelName, text)
-}
-
-func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage {
-	msgs := make([]*vertexaiclient.ChatMessage, len(messages))
-	for i, m := range messages {
-		msg := &vertexaiclient.ChatMessage{
-			Content: m.GetContent(),
-		}
-		typ := m.GetType()
-		switch typ {
-		case schema.ChatMessageTypeSystem:
-			msg.Author = botAuthor
-		case schema.ChatMessageTypeAI:
-			msg.Author = botAuthor
-		case schema.ChatMessageTypeHuman:
-			msg.Author = userAuthor
-		case schema.ChatMessageTypeGeneric:
-			msg.Author = userAuthor
-		case schema.ChatMessageTypeFunction:
-			msg.Author = userAuthor
-		}
-		if n, ok := m.(schema.Named); ok {
-			msg.Author = n.GetName()
-		}
-		msgs[i] = msg
-	}
-	return msgs
-}
-
 // New returns a new VertexAI PaLM LLM.
 func New(opts ...Option) (*LLM, error) {
 	client, err := newClient(opts...)
 	return &LLM{client: client}, err
 }
 
-// New returns a new VertexAI PaLM Chat LLM.
-func NewChat(opts ...Option) (*Chat, error) {
-	client, err := newClient(opts...)
-	return &Chat{client: client}, err
-}
-
 func newClient(opts ...Option) (*vertexaiclient.PaLMClient, error) {
 	// Ensure options are initialized only once.
 	initOptions.Do(initOpts)
diff --git a/llms/vertexai/vertexai_palm_llm_chat.go b/llms/vertexai/vertexai_palm_llm_chat.go
new file mode 100644
index 000000000..9504e5752
--- /dev/null
+++ b/llms/vertexai/vertexai_palm_llm_chat.go
@@ -0,0 +1,129 @@
+package vertexai
+
+import (
+	"context"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/vertexai/internal/vertexaiclient"
+	"github.com/tmc/langchaingo/schema"
+)
+
+const (
+	userAuthor = "user"
+	botAuthor  = "bot"
+)
+
+type ChatMessage = vertexaiclient.ChatMessage
+
+type Chat struct {
+	client *vertexaiclient.PaLMClient
+}
+
+var (
+	_ llms.ChatLLM       = (*Chat)(nil)
+	_ llms.LanguageModel = (*Chat)(nil)
+)
+
+// Call requests a chat response for the given messages.
+func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
+	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
+	if err != nil {
+		return nil, err
+	}
+	if len(r) == 0 {
+		return nil, ErrEmptyResponse
+	}
+	return r[0].Message, nil
+}
+
+// Generate requests a chat response for each of the sets of messages.
+func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll
+	opts := llms.CallOptions{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+	if opts.StreamingFunc != nil {
+		return nil, ErrNotImplemented
+	}
+
+	generations := make([]*llms.Generation, 0, len(messageSets))
+	for _, messages := range messageSets {
+		msgs := toClientChatMessage(messages)
+		result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{
+			Temperature: opts.Temperature,
+			Messages:    msgs,
+		})
+		if err != nil {
+			return nil, err
+		}
+		if len(result.Candidates) == 0 {
+			return nil, ErrEmptyResponse
+		}
+		generations = append(generations, &llms.Generation{
+			Message: &schema.AIChatMessage{
+				Content: result.Candidates[0].Content,
+			},
+			Text: result.Candidates[0].Content,
+		})
+	}
+
+	return generations, nil
+}
+
+func (o *Chat) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
+	return llms.GenerateChatPrompt(ctx, o, promptValues, options...)
+}
+
+func (o *Chat) GetNumTokens(text string) int {
+	return llms.CountTokens(vertexaiclient.TextModelName, text)
+}
+
+func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage {
+	msgs := make([]*vertexaiclient.ChatMessage, len(messages))
+	for i, m := range messages {
+		msg := &vertexaiclient.ChatMessage{
+			Content: m.GetContent(),
+		}
+		typ := m.GetType()
+		switch typ {
+		case schema.ChatMessageTypeSystem:
+			msg.Author = botAuthor
+		case schema.ChatMessageTypeAI:
+			msg.Author = botAuthor
+		case schema.ChatMessageTypeHuman:
+			msg.Author = userAuthor
+		case schema.ChatMessageTypeGeneric:
+			msg.Author = userAuthor
+		case schema.ChatMessageTypeFunction:
+			msg.Author = userAuthor
+		}
+		if n, ok := m.(schema.Named); ok {
+			msg.Author = n.GetName()
+		}
+		msgs[i] = msg
+	}
+	return msgs
+}
+
+// NewChat returns a new VertexAI PaLM Chat LLM.
+func NewChat(opts ...Option) (*Chat, error) {
+	client, err := newClient(opts...)
+	return &Chat{client: client}, err
+}
+
+// CreateEmbedding creates embeddings for the given input texts.
+func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float64, error) {
+	embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{
+		Input: inputTexts,
+	})
+	if err != nil {
+		return nil, err
+	}
+	if len(embeddings) == 0 {
+		return nil, ErrEmptyResponse
+	}
+	if len(inputTexts) != len(embeddings) {
+		return embeddings, ErrUnexpectedResponseLength
+	}
+	return embeddings, nil
+}
diff --git a/textsplitter/split_documents.go b/textsplitter/split_documents.go
index 654641243..674ea7ec9 100644
--- a/textsplitter/split_documents.go
+++ b/textsplitter/split_documents.go
@@ -68,7 +68,7 @@ func joinDocs(docs []string, separator string) string {
 }
 
 // mergeSplits merges smaller splits into splits that are closer to the chunkSize.
-func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int) []string {
+func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int) []string { //nolint:cyclop
 	docs := make([]string, 0)
 	currentDoc := make([]string, 0)
 	total := 0
@@ -87,12 +87,11 @@ func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap
 		}
 
 		for shouldPop(chunkOverlap, chunkSize, total, len(split), len(separator), len(currentDoc)) {
-			total -= len(currentDoc[0])
+			total -= len(currentDoc[0]) //nolint:gosec
 			if len(currentDoc) > 1 {
 				total -= len(separator)
 			}
-
-			currentDoc = currentDoc[1:]
+			currentDoc = currentDoc[1:] //nolint:gosec
 		}
 	}
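
For reviewers, a minimal usage sketch of the new vertexaichat embedder introduced by this diff. This is illustrative only and not part of the patch; it uses only identifiers the diff defines (NewChatVertexAI, WithBatchSize, WithStripNewLines, EmbedDocuments) and assumes GOOGLE_CLOUD_PROJECT and Vertex AI credentials are configured, as in the tests above.

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/tmc/langchaingo/embeddings/vertexai/vertexaichat"
	)

	func main() {
		// Build the embedder; when options are omitted, the defaults
		// from options.go apply (BatchSize 512, StripNewLines true).
		e, err := vertexaichat.NewChatVertexAI(
			vertexaichat.WithBatchSize(5),
			vertexaichat.WithStripNewLines(false),
		)
		if err != nil {
			log.Fatal(err)
		}

		// EmbedDocuments returns one vector per input document.
		vecs, err := e.EmbedDocuments(context.Background(), []string{"Hello world", "good bye"})
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(len(vecs)) // 2
	}

Because the client defaults via vertexai.NewChat() inside applyChatClientOptions, WithClient is only needed when reusing an existing chat client, as the second test does.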