From 35391ceb0067e1b667a8bcb81246a3869883dfde Mon Sep 17 00:00:00 2001 From: zivkovicn Date: Thu, 10 Aug 2023 17:41:31 +0200 Subject: [PATCH 1/4] feature-huggingface-embeddings | huggingface embeddings --- embeddings/huggingface/huggingface.go | 72 ++++++++++++++++++ embeddings/huggingface/huggingface_test.go | 27 +++++++ embeddings/huggingface/options.go | 73 +++++++++++++++++++ llms/huggingface/huggingfacellm.go | 26 ++++++- .../internal/huggingfaceclient/embeddings.go | 61 ++++++++++++++++ .../huggingfaceclient/huggingfaceclient.go | 37 +++++++++- .../internal/huggingfaceclient/inference.go | 4 +- 7 files changed, 295 insertions(+), 5 deletions(-) create mode 100644 embeddings/huggingface/huggingface.go create mode 100644 embeddings/huggingface/huggingface_test.go create mode 100644 embeddings/huggingface/options.go create mode 100644 llms/huggingface/internal/huggingfaceclient/embeddings.go diff --git a/embeddings/huggingface/huggingface.go b/embeddings/huggingface/huggingface.go new file mode 100644 index 000000000..1668afe3c --- /dev/null +++ b/embeddings/huggingface/huggingface.go @@ -0,0 +1,72 @@ +package huggingface + +import ( + "context" + "strings" + + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/llms/huggingface" +) + +// Huggingface is the embedder using the Huggingface hub api. +type Huggingface struct { + client *huggingface.LLM + Model string + Task string + + StripNewLines bool + BatchSize int +} + +var _ embeddings.Embedder = &Huggingface{} + +func NewHuggingface(opts ...Option) (*Huggingface, error) { + v, err := applyOptions(opts...) + if err != nil { + return nil, err + } + + return v, nil +} + +func (e *Huggingface) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) { + batchedTexts := embeddings.BatchTexts( + embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), + e.BatchSize, + ) + + emb := make([][]float64, 0, len(texts)) + for _, texts := range batchedTexts { + curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts, e.Model, e.Task) + if err != nil { + return nil, err + } + + textLengths := make([]int, 0, len(texts)) + for _, text := range texts { + textLengths = append(textLengths, len(text)) + } + + combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) + if err != nil { + return nil, err + } + + emb = append(emb, combined) + } + + return emb, nil +} + +func (e *Huggingface) EmbedQuery(ctx context.Context, text string) ([]float64, error) { + if e.StripNewLines { + text = strings.ReplaceAll(text, "\n", " ") + } + + emb, err := e.client.CreateEmbedding(ctx, []string{text}, e.Model, e.Task) + if err != nil { + return nil, err + } + + return emb[0], nil +} diff --git a/embeddings/huggingface/huggingface_test.go b/embeddings/huggingface/huggingface_test.go new file mode 100644 index 000000000..288fc24fc --- /dev/null +++ b/embeddings/huggingface/huggingface_test.go @@ -0,0 +1,27 @@ +package huggingface + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHuggingfaceEmbeddings(t *testing.T) { + t.Parallel() + + if openaiKey := os.Getenv("HUGGINGFACEHUB_API_TOKEN"); openaiKey == "" { + t.Skip("HUGGINGFACEHUB_API_TOKEN not set") + } + e, err := NewHuggingface() + require.NoError(t, err) + + _, err = e.EmbedQuery(context.Background(), "Hello world!") + require.NoError(t, err) + + embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world", "The world is ending", "good bye"}) + require.NoError(t, err) + assert.Len(t, embeddings, 3) +} diff --git a/embeddings/huggingface/options.go b/embeddings/huggingface/options.go new file mode 100644 index 000000000..883d998b3 --- /dev/null +++ b/embeddings/huggingface/options.go @@ -0,0 +1,73 @@ +package huggingface + +import ( + "github.com/tmc/langchaingo/llms/huggingface" +) + +const ( + _defaultBatchSize = 512 + _defaultStripNewLines = true + _defaultModel = "sentence-transformers/all-mpnet-base-v2" + _defaultTask = "feature-extraction" +) + +// Option is a function type that can be used to modify the client. +type Option func(p *Huggingface) + +// WithModel is an option for providing the model name to use. +func WithModel(model string) Option { + return func(p *Huggingface) { + p.Model = model + } +} + +// WithTask is an option for providing the task to call the model with. +func WithTask(task string) Option { + return func(p *Huggingface) { + p.Task = task + } +} + +// WithClient is an option for providing the LLM client. +func WithClient(client huggingface.LLM) Option { + return func(p *Huggingface) { + p.client = &client + } +} + +// WithStripNewLines is an option for specifying the should it strip new lines. +func WithStripNewLines(stripNewLines bool) Option { + return func(p *Huggingface) { + p.StripNewLines = stripNewLines + } +} + +// WithBatchSize is an option for specifying the batch size. +func WithBatchSize(batchSize int) Option { + return func(p *Huggingface) { + p.BatchSize = batchSize + } +} + +func applyOptions(opts ...Option) (*Huggingface, error) { + o := &Huggingface{ + StripNewLines: _defaultStripNewLines, + BatchSize: _defaultBatchSize, + Model: _defaultModel, + Task: _defaultTask, + } + + for _, opt := range opts { + opt(o) + } + + if o.client == nil { + client, err := huggingface.New() + if err != nil { + return nil, err + } + o.client = client + } + + return o, nil +} diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index 53a68b422..eeffdc315 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -11,8 +11,9 @@ import ( ) var ( - ErrEmptyResponse = errors.New("empty response") - ErrMissingToken = errors.New("missing the Hugging Face API token. Set it in the HUGGINGFACEHUB_API_TOKEN environment variable") //nolint:lll + ErrEmptyResponse = errors.New("empty response") + ErrMissingToken = errors.New("missing the Hugging Face API token. Set it in the HUGGINGFACEHUB_API_TOKEN environment variable") //nolint:lll + ErrUnexpectedResponseLength = errors.New("unexpected length of response") ) type LLM struct { @@ -91,3 +92,24 @@ func New(opts ...Option) (*LLM, error) { client: c, }, nil } + +// CreateEmbedding creates embeddings for the given input texts. +func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string, model string, task string) ([][]float64, error) { + embeddings, err := o.client.CreateEmbedding(ctx, model, task, &huggingfaceclient.EmbeddingRequest{ + Inputs: inputTexts, + Options: map[string]any{ + "use_gpu": false, + "wait_for_model": true, + }, + }) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, ErrEmptyResponse + } + if len(inputTexts) != len(embeddings) { + return embeddings, ErrUnexpectedResponseLength + } + return embeddings, nil +} diff --git a/llms/huggingface/internal/huggingfaceclient/embeddings.go b/llms/huggingface/internal/huggingfaceclient/embeddings.go new file mode 100644 index 000000000..ecfe4a8c2 --- /dev/null +++ b/llms/huggingface/internal/huggingfaceclient/embeddings.go @@ -0,0 +1,61 @@ +package huggingfaceclient + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +type embeddingPayload struct { + Options map[string]any + Inputs []string `json:"inputs"` +} + +// nolint:lll +func (c *Client) createEmbedding(ctx context.Context, model string, task string, payload *embeddingPayload) ([][]float32, error) { + body := map[string]any{ + "inputs": payload.Inputs, + } + for key, value := range payload.Options { + body[key] = value + } + + payloadBytes, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal payload: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/pipeline/%s/%s", c.url, task, model), bytes.NewReader(payloadBytes)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.Token) + req.Header.Set("Content-Type", "application/json") + + r, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer r.Body.Close() + + if r.StatusCode != http.StatusOK { + msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode) + + //// No need to check the error here: if it fails, we'll just return the + //// status code. + //var errResp errorMessage + //if err := json.NewDecoder(r.Body).Decode(&errResp); err != nil { + // return nil, errors.New(msg) // nolint:goerr113 + //} + + return nil, fmt.Errorf("%s: %s", msg, "TODO message from error") // nolint:goerr113 + } + + var response [][]float32 + if err := json.NewDecoder(r.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return response, nil +} diff --git a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go index 53941cc5d..68fdb46df 100644 --- a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go +++ b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go @@ -24,7 +24,7 @@ func New(token string, model string) (*Client, error) { return &Client{ Token: token, Model: model, - url: hfInferenceAPI, + url: huggingfaceAPIBaseURL, }, nil } @@ -73,3 +73,38 @@ func (c *Client) RunInference(ctx context.Context, request *InferenceRequest) (* Text: text, }, nil } + +// EmbeddingRequest is a request to create an embedding. +type EmbeddingRequest struct { + Options map[string]any `json:"options"` + Inputs []string `json:"inputs"` +} + +// CreateEmbedding creates embeddings. +func (c *Client) CreateEmbedding(ctx context.Context, model string, task string, r *EmbeddingRequest) ([][]float64, error) { + resp, err := c.createEmbedding(ctx, model, task, &embeddingPayload{ + Inputs: r.Inputs, + Options: r.Options, + }) + if err != nil { + return nil, err + } + + if len(resp) == 0 { + return nil, ErrEmptyResponse + } + + return c.convertFloat32ToFloat64(resp), nil +} + +func (c *Client) convertFloat32ToFloat64(input [][]float32) [][]float64 { + output := make([][]float64, len(input)) + for i, row := range input { + output[i] = make([]float64, len(row)) + for j, val := range row { + output[i][j] = float64(val) + } + } + + return output +} diff --git a/llms/huggingface/internal/huggingfaceclient/inference.go b/llms/huggingface/internal/huggingfaceclient/inference.go index be4a000b8..bcf49535d 100644 --- a/llms/huggingface/internal/huggingfaceclient/inference.go +++ b/llms/huggingface/internal/huggingfaceclient/inference.go @@ -43,7 +43,7 @@ type ( } ) -const hfInferenceAPI = "https://api-inference.huggingface.co/models/" +const huggingfaceAPIBaseURL = "https://api-inference.huggingface.co" func (c *Client) runInference(ctx context.Context, payload *inferencePayload) (inferenceResponsePayload, error) { payloadBytes, err := json.Marshal(payload) @@ -51,7 +51,7 @@ func (c *Client) runInference(ctx context.Context, payload *inferencePayload) (i return nil, err } body := bytes.NewReader(payloadBytes) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.url, payload.Model), body) //nolint:lll + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/models/%s", c.url, payload.Model), body) //nolint:lll if err != nil { return nil, err } From 271cd31daa1c12d93613f51b72474c2d17ef7c36 Mon Sep 17 00:00:00 2001 From: zivkovicn Date: Thu, 10 Aug 2023 17:43:20 +0200 Subject: [PATCH 2/4] feature-huggingface-embeddings | cp --- .../huggingface/internal/huggingfaceclient/embeddings.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/llms/huggingface/internal/huggingfaceclient/embeddings.go b/llms/huggingface/internal/huggingfaceclient/embeddings.go index ecfe4a8c2..b698a237a 100644 --- a/llms/huggingface/internal/huggingfaceclient/embeddings.go +++ b/llms/huggingface/internal/huggingfaceclient/embeddings.go @@ -42,14 +42,7 @@ func (c *Client) createEmbedding(ctx context.Context, model string, task string, if r.StatusCode != http.StatusOK { msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode) - //// No need to check the error here: if it fails, we'll just return the - //// status code. - //var errResp errorMessage - //if err := json.NewDecoder(r.Body).Decode(&errResp); err != nil { - // return nil, errors.New(msg) // nolint:goerr113 - //} - - return nil, fmt.Errorf("%s: %s", msg, "TODO message from error") // nolint:goerr113 + return nil, fmt.Errorf("%s: %s", msg, "unable to create embeddings") // nolint:goerr113 } var response [][]float32 From 934038891ce82901972fde3475fb8551e036df98 Mon Sep 17 00:00:00 2001 From: zivkovicn Date: Thu, 10 Aug 2023 17:45:06 +0200 Subject: [PATCH 3/4] feature-huggingface-embeddings | cp --- .../huggingface/internal/huggingfaceclient/huggingfaceclient.go | 2 ++ llms/huggingface/internal/huggingfaceclient/inference.go | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go index 68fdb46df..25f28ce86 100644 --- a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go +++ b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go @@ -11,6 +11,8 @@ var ( ErrEmptyResponse = errors.New("empty response") ) +const huggingfaceAPIBaseURL = "https://api-inference.huggingface.co" + type Client struct { Token string Model string diff --git a/llms/huggingface/internal/huggingfaceclient/inference.go b/llms/huggingface/internal/huggingfaceclient/inference.go index bcf49535d..6c00ec475 100644 --- a/llms/huggingface/internal/huggingfaceclient/inference.go +++ b/llms/huggingface/internal/huggingfaceclient/inference.go @@ -43,8 +43,6 @@ type ( } ) -const huggingfaceAPIBaseURL = "https://api-inference.huggingface.co" - func (c *Client) runInference(ctx context.Context, payload *inferencePayload) (inferenceResponsePayload, error) { payloadBytes, err := json.Marshal(payload) if err != nil { From 7ec2d7d35ab5da6b5b3ba55a9f6c9998e562b23e Mon Sep 17 00:00:00 2001 From: zivkovicn Date: Thu, 10 Aug 2023 17:51:04 +0200 Subject: [PATCH 4/4] feature-huggingface-embeddings | lint --- llms/huggingface/huggingfacellm.go | 7 ++++++- .../internal/huggingfaceclient/huggingfaceclient.go | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index eeffdc315..f42c8b818 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -94,7 +94,12 @@ func New(opts ...Option) (*LLM, error) { } // CreateEmbedding creates embeddings for the given input texts. -func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string, model string, task string) ([][]float64, error) { +func (o *LLM) CreateEmbedding( + ctx context.Context, + inputTexts []string, + model string, + task string, +) ([][]float64, error) { embeddings, err := o.client.CreateEmbedding(ctx, model, task, &huggingfaceclient.EmbeddingRequest{ Inputs: inputTexts, Options: map[string]any{ diff --git a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go index 25f28ce86..81c6aabe1 100644 --- a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go +++ b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go @@ -83,7 +83,12 @@ type EmbeddingRequest struct { } // CreateEmbedding creates embeddings. -func (c *Client) CreateEmbedding(ctx context.Context, model string, task string, r *EmbeddingRequest) ([][]float64, error) { +func (c *Client) CreateEmbedding( + ctx context.Context, + model string, + task string, + r *EmbeddingRequest, +) ([][]float64, error) { resp, err := c.createEmbedding(ctx, model, task, &embeddingPayload{ Inputs: r.Inputs, Options: r.Options,