Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

embeddings & llm: Creating the actual embeddings for the VertexAI chat, splitting llms for VertexAI #262

Merged
merged 6 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion embeddings/vertexai/vertexai_palm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestVertexAIPaLMEmbeddingsWithOptions(t *testing.T) {
client, err := vertexai.New()
require.NoError(t, err)

e, err := NewVertexAIPaLM(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
e, err := NewVertexAIPaLM(WithClient(*client), WithBatchSize(5), WithStripNewLines(false))
require.NoError(t, err)

_, err = e.EmbedQuery(context.Background(), "Hello world!")
Expand Down
54 changes: 54 additions & 0 deletions embeddings/vertexai/vertexaichat/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package vertexaichat

import (
"github.com/tmc/langchaingo/llms/vertexai"
)

const (
	// _defaultBatchSize is the batch size passed to embeddings.BatchTexts
	// when no WithBatchSize option is supplied.
	_defaultBatchSize = 512
	// _defaultStripNewLines makes newline stripping opt-out by default.
	_defaultStripNewLines = true
)

// ChatOption is a functional option that configures a ChatVertexAI embedder.
type ChatOption func(p *ChatVertexAI)

// WithClient sets the VertexAI chat client used by the embedder.
func WithClient(client vertexai.Chat) ChatOption {
	return func(e *ChatVertexAI) {
		// Keep a stable copy so the stored pointer does not alias the
		// caller's value.
		c := client
		e.client = &c
	}
}

// WithStripNewLines sets whether newlines are stripped from input texts
// before they are embedded.
func WithStripNewLines(stripNewLines bool) ChatOption {
	return func(e *ChatVertexAI) {
		e.StripNewLines = stripNewLines
	}
}

// WithBatchSize sets the batch size used when splitting texts for embedding.
func WithBatchSize(batchSize int) ChatOption {
	return func(e *ChatVertexAI) {
		e.BatchSize = batchSize
	}
}

// applyChatClientOptions builds a ChatVertexAI from the package defaults
// overlaid with the supplied options, creating a default chat client when
// none was provided via WithClient.
func applyChatClientOptions(opts ...ChatOption) (ChatVertexAI, error) {
	embedder := ChatVertexAI{
		StripNewLines: _defaultStripNewLines,
		BatchSize:     _defaultBatchSize,
	}

	for _, opt := range opts {
		opt(&embedder)
	}

	// Fall back to a freshly constructed client if the caller supplied none.
	if embedder.client == nil {
		defaultClient, err := vertexai.NewChat()
		if err != nil {
			return ChatVertexAI{}, err
		}
		embedder.client = defaultClient
	}

	return embedder, nil
}
71 changes: 71 additions & 0 deletions embeddings/vertexai/vertexaichat/vertexai_chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package vertexaichat

import (
"context"
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/vertexai"
)

// ChatVertexAI is the embedder using the VertexAI api.
type ChatVertexAI struct {
	// client is the VertexAI chat client used to create embeddings.
	client *vertexai.Chat

	// StripNewLines replaces newlines with spaces before embedding.
	StripNewLines bool
	// BatchSize is the batch size passed to embeddings.BatchTexts.
	BatchSize int
}

// Compile-time check that ChatVertexAI implements embeddings.Embedder.
var _ embeddings.Embedder = ChatVertexAI{}

// NewChatVertexAI constructs a ChatVertexAI embedder. Options configure the
// chat client, newline stripping, and batch size.
func NewChatVertexAI(opts ...ChatOption) (ChatVertexAI, error) {
	// applyChatClientOptions already returns the zero ChatVertexAI on error,
	// so its result can be returned directly.
	return applyChatClientOptions(opts...)
}

// EmbedDocuments creates embeddings for the given texts. Inputs are grouped
// with embeddings.BatchTexts; the vectors returned for each group are merged
// into one vector via embeddings.CombineVectors, weighted by text length.
// NOTE(review): assumes BatchTexts yields one group per input text so that
// the result has len(texts) entries — confirm against the embeddings helpers.
func (e ChatVertexAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
	cleaned := embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
	batchedTexts := embeddings.BatchTexts(cleaned, e.BatchSize)

	result := make([][]float64, 0, len(texts))
	for _, batch := range batchedTexts {
		batchEmbeddings, err := e.client.CreateEmbedding(ctx, batch)
		if err != nil {
			return nil, err
		}

		// Weight each vector by its source text's length when combining.
		lengths := make([]int, len(batch))
		for i, text := range batch {
			lengths[i] = len(text)
		}

		combined, err := embeddings.CombineVectors(batchEmbeddings, lengths)
		if err != nil {
			return nil, err
		}

		result = append(result, combined)
	}

	return result, nil
}

// EmbedQuery embeds a single query text, replacing newlines with spaces
// first when StripNewLines is enabled.
func (e ChatVertexAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
	query := text
	if e.StripNewLines {
		query = strings.ReplaceAll(query, "\n", " ")
	}

	vectors, err := e.client.CreateEmbedding(ctx, []string{query})
	if err != nil {
		return nil, err
	}

	// A single input yields a single vector; return it directly.
	return vectors[0], nil
}
50 changes: 50 additions & 0 deletions embeddings/vertexai/vertexaichat/vertexai_chat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package vertexaichat

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/llms/vertexai"
)

// TestVertexAIChatEmbeddings exercises the default embedder end-to-end
// against a live VertexAI project.
func TestVertexAIChatEmbeddings(t *testing.T) {
	t.Parallel()

	// Skip unless credentials/project are configured in the environment.
	if os.Getenv("GOOGLE_CLOUD_PROJECT") == "" {
		t.Skip("GOOGLE_CLOUD_PROJECT not set")
	}

	embedder, err := NewChatVertexAI()
	require.NoError(t, err)

	_, err = embedder.EmbedQuery(context.Background(), "Hello world!")
	require.NoError(t, err)

	docs := []string{"Hello world", "The world is ending", "good bye"}
	vectors, err := embedder.EmbedDocuments(context.Background(), docs)
	require.NoError(t, err)
	assert.Len(t, vectors, 3)
}

// TestVertexAIChatEmbeddingsWithOptions verifies that an explicitly
// configured embedder (client, batch size, newline handling) works.
func TestVertexAIChatEmbeddingsWithOptions(t *testing.T) {
	t.Parallel()

	// Skip unless credentials/project are configured in the environment.
	if os.Getenv("GOOGLE_CLOUD_PROJECT") == "" {
		t.Skip("GOOGLE_CLOUD_PROJECT not set")
	}

	client, err := vertexai.NewChat()
	require.NoError(t, err)

	embedder, err := NewChatVertexAI(
		WithClient(*client),
		WithBatchSize(5),
		WithStripNewLines(false),
	)
	require.NoError(t, err)

	_, err = embedder.EmbedQuery(context.Background(), "Hello world!")
	require.NoError(t, err)

	vectors, err := embedder.EmbedDocuments(context.Background(), []string{"Hello world"})
	require.NoError(t, err)
	assert.Len(t, vectors, 1)
}
103 changes: 0 additions & 103 deletions llms/vertexai/vertexai_palm_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ var (
ErrNotImplemented = errors.New("not implemented")
)

// Author labels used when attributing chat messages to a speaker.
// NOTE(review): presumably these match the PaLM chat API's expected author
// values — verify against the client package.
const (
	userAuthor = "user"
	botAuthor  = "bot"
)

// LLM is the VertexAI PaLM language model, backed by a PaLM client.
type LLM struct {
	// client performs the underlying VertexAI API calls.
	client *vertexaiclient.PaLMClient
}
Expand Down Expand Up @@ -92,110 +87,12 @@ func (o *LLM) GetNumTokens(text string) int {
return llms.CountTokens(vertexaiclient.TextModelName, text)
}

// ChatMessage aliases the client package's chat message type so callers do
// not need to import vertexaiclient directly.
type ChatMessage = vertexaiclient.ChatMessage

// Chat is the VertexAI PaLM chat model, backed by a PaLM client.
type Chat struct {
	// client performs the underlying VertexAI API calls.
	client *vertexaiclient.PaLMClient
}

// Compile-time checks that Chat satisfies the llms interfaces.
var (
	_ llms.ChatLLM       = (*Chat)(nil)
	_ llms.LanguageModel = (*Chat)(nil)
)

// Call requests a chat response for the given messages.
func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
	// Delegate to Generate with a single message set.
	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
	if err != nil {
		return nil, err
	}
	// An empty result means the backend produced no generations.
	if len(r) == 0 {
		return nil, ErrEmptyResponse
	}
	return r[0].Message, nil
}

// Generate requests a chat response for each of the sets of messages.
// It returns one Generation per message set, built from the first candidate
// of each response.
func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll
	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}
	// Streaming is not supported by this implementation.
	if opts.StreamingFunc != nil {
		return nil, ErrNotImplemented
	}

	generations := make([]*llms.Generation, 0, len(messageSets))
	for _, messages := range messageSets {
		msgs := toClientChatMessage(messages)
		result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{
			Temperature: opts.Temperature,
			Messages:    msgs,
		})
		if err != nil {
			return nil, err
		}
		if len(result.Candidates) == 0 {
			return nil, ErrEmptyResponse
		}
		// Only the first candidate is surfaced.
		generations = append(generations, &llms.Generation{
			Message: &schema.AIChatMessage{
				Content: result.Candidates[0].Content,
			},
			Text: result.Candidates[0].Content,
		})
	}

	return generations, nil
}

// GeneratePrompt adapts prompt values to chat generation via the shared
// llms helper.
func (o *Chat) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
	return llms.GenerateChatPrompt(ctx, o, promptValues, options...)
}

// GetNumTokens reports the token count of text using the text model's
// tokenizer.
func (o *Chat) GetNumTokens(text string) int {
	return llms.CountTokens(vertexaiclient.TextModelName, text)
}

// toClientChatMessage converts schema chat messages into the client's
// ChatMessage representation, deriving an author label from each message's
// type.
func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage {
	msgs := make([]*vertexaiclient.ChatMessage, len(messages))
	for i, m := range messages {
		msg := &vertexaiclient.ChatMessage{
			Content: m.GetContent(),
		}
		typ := m.GetType()
		switch typ {
		// NOTE(review): system messages are mapped to the bot author —
		// confirm this matches the PaLM chat API's expectations.
		case schema.ChatMessageTypeSystem:
			msg.Author = botAuthor
		case schema.ChatMessageTypeAI:
			msg.Author = botAuthor
		case schema.ChatMessageTypeHuman:
			msg.Author = userAuthor
		case schema.ChatMessageTypeGeneric:
			msg.Author = userAuthor
		case schema.ChatMessageTypeFunction:
			msg.Author = userAuthor
		}
		// A message implementing schema.Named overrides the derived author.
		if n, ok := m.(schema.Named); ok {
			msg.Author = n.GetName()
		}
		msgs[i] = msg
	}
	return msgs
}

// New returns a new VertexAI PaLM LLM.
// Note: the error from newClient is returned alongside a partially
// constructed LLM; callers must check err before use.
func New(opts ...Option) (*LLM, error) {
	client, err := newClient(opts...)
	return &LLM{client: client}, err
}

// NewChat returns a new VertexAI PaLM Chat LLM.
// Note: the error from newClient is returned alongside a partially
// constructed Chat; callers must check err before use.
func NewChat(opts ...Option) (*Chat, error) {
	client, err := newClient(opts...)
	return &Chat{client: client}, err
}

func newClient(opts ...Option) (*vertexaiclient.PaLMClient, error) {
// Ensure options are initialized only once.
initOptions.Do(initOpts)
Expand Down
Loading
Loading