Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

embeddings/huggingface: Creating embedder for the Huggingface hub #246

Merged
merged 5 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions embeddings/huggingface/huggingface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package huggingface

import (
"context"
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/huggingface"
)

// Huggingface is the embedder using the Huggingface hub api.
type Huggingface struct {
	// client is the Hugging Face LLM client used to create embeddings.
	client *huggingface.LLM
	// Model is the model name sent with every embedding request.
	Model string
	// Task is the inference pipeline task (e.g. "feature-extraction")
	// used to build the request URL.
	Task string

	// StripNewLines, when true, replaces newlines with spaces in input
	// texts before they are embedded.
	StripNewLines bool
	// BatchSize is the batch size passed to embeddings.BatchTexts when
	// splitting documents.
	BatchSize int
}

// Compile-time check that Huggingface satisfies embeddings.Embedder.
var _ embeddings.Embedder = &Huggingface{}

// NewHuggingface creates a new Huggingface embedder configured with the
// given options.
func NewHuggingface(opts ...Option) (*Huggingface, error) {
	// applyOptions already returns (embedder, nil) or (nil, err), so the
	// result can be passed through directly.
	return applyOptions(opts...)
}

// EmbedDocuments creates one embedding vector per input text. Texts are
// first batched with embeddings.BatchTexts, each batch is embedded through
// the Hugging Face client, and the per-batch vectors are combined with
// embeddings.CombineVectors using the text lengths as weights.
//
// Fix: the inner loop variable previously shadowed the texts parameter;
// it is renamed to batch for clarity.
func (e *Huggingface) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
	batchedTexts := embeddings.BatchTexts(
		embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
		e.BatchSize,
	)

	emb := make([][]float64, 0, len(texts))
	for _, batch := range batchedTexts {
		curTextEmbeddings, err := e.client.CreateEmbedding(ctx, batch, e.Model, e.Task)
		if err != nil {
			return nil, err
		}

		// Collect the length of each chunk so CombineVectors can weight
		// the chunk embeddings accordingly.
		textLengths := make([]int, 0, len(batch))
		for _, text := range batch {
			textLengths = append(textLengths, len(text))
		}

		combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
		if err != nil {
			return nil, err
		}

		emb = append(emb, combined)
	}

	return emb, nil
}

// EmbedQuery returns the embedding vector for a single query text.
func (e *Huggingface) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
	input := text
	if e.StripNewLines {
		input = strings.ReplaceAll(input, "\n", " ")
	}

	vectors, err := e.client.CreateEmbedding(ctx, []string{input}, e.Model, e.Task)
	if err != nil {
		return nil, err
	}

	// A single input yields a single vector; return it directly.
	return vectors[0], nil
}
27 changes: 27 additions & 0 deletions embeddings/huggingface/huggingface_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package huggingface

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestHuggingfaceEmbeddings exercises EmbedQuery and EmbedDocuments against
// the live Hugging Face hub API. It is skipped unless the
// HUGGINGFACEHUB_API_TOKEN environment variable is set.
//
// Fix: the token variable was misnamed openaiKey (copy-paste from the
// OpenAI embedder test); the check is now a direct env lookup.
func TestHuggingfaceEmbeddings(t *testing.T) {
	t.Parallel()

	if os.Getenv("HUGGINGFACEHUB_API_TOKEN") == "" {
		t.Skip("HUGGINGFACEHUB_API_TOKEN not set")
	}
	e, err := NewHuggingface()
	require.NoError(t, err)

	_, err = e.EmbedQuery(context.Background(), "Hello world!")
	require.NoError(t, err)

	embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world", "The world is ending", "good bye"})
	require.NoError(t, err)
	assert.Len(t, embeddings, 3)
}
73 changes: 73 additions & 0 deletions embeddings/huggingface/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package huggingface

import (
"github.com/tmc/langchaingo/llms/huggingface"
)

const (
	// _defaultBatchSize is the default batch size used when splitting texts.
	_defaultBatchSize = 512
	// _defaultStripNewLines strips newlines from inputs by default.
	_defaultStripNewLines = true
	// _defaultModel is the default sentence-transformers embedding model.
	_defaultModel = "sentence-transformers/all-mpnet-base-v2"
	// _defaultTask is the default Hugging Face inference pipeline task.
	_defaultTask = "feature-extraction"
)

// Option is a function type that can be used to modify the client.
// Options are applied in order by applyOptions during construction.
type Option func(p *Huggingface)

// WithModel sets the model name used for embedding requests.
func WithModel(model string) Option {
	return func(h *Huggingface) {
		h.Model = model
	}
}

// WithTask sets the inference task the model is called with.
func WithTask(task string) Option {
	return func(h *Huggingface) {
		h.Task = task
	}
}

// WithClient sets the Hugging Face LLM client used by the embedder.
// The client is received by value; the embedder stores a pointer to the copy.
func WithClient(client huggingface.LLM) Option {
	return func(h *Huggingface) {
		h.client = &client
	}
}

// WithStripNewLines sets whether newlines should be stripped from input
// texts before embedding.
func WithStripNewLines(stripNewLines bool) Option {
	return func(h *Huggingface) {
		h.StripNewLines = stripNewLines
	}
}

// WithBatchSize sets the batch size used when splitting texts.
func WithBatchSize(batchSize int) Option {
	return func(h *Huggingface) {
		h.BatchSize = batchSize
	}
}

// applyOptions builds a Huggingface embedder from the package defaults,
// applies the supplied options, and lazily constructs an LLM client when
// none was provided via WithClient.
func applyOptions(opts ...Option) (*Huggingface, error) {
	h := &Huggingface{
		StripNewLines: _defaultStripNewLines,
		BatchSize:     _defaultBatchSize,
		Model:         _defaultModel,
		Task:          _defaultTask,
	}

	for _, opt := range opts {
		opt(h)
	}

	// A client supplied through options wins; otherwise build the default.
	if h.client != nil {
		return h, nil
	}

	client, err := huggingface.New()
	if err != nil {
		return nil, err
	}
	h.client = client

	return h, nil
}
31 changes: 29 additions & 2 deletions llms/huggingface/huggingfacellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (
)

var (
	// ErrEmptyResponse is returned when the Hugging Face API yields no data.
	ErrEmptyResponse = errors.New("empty response")
	// ErrMissingToken is returned when no Hugging Face API token is configured.
	ErrMissingToken = errors.New("missing the Hugging Face API token. Set it in the HUGGINGFACEHUB_API_TOKEN environment variable") //nolint:lll
	// ErrUnexpectedResponseLength is returned when the number of embedding
	// vectors returned does not match the number of input texts.
	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
)

type LLM struct {
Expand Down Expand Up @@ -91,3 +92,29 @@ func New(opts ...Option) (*LLM, error) {
client: c,
}, nil
}

// CreateEmbedding creates embeddings for the given input texts.
//
// model and task select the Hugging Face inference pipeline to call. Every
// request is sent with use_gpu=false and wait_for_model=true.
//
// One vector is expected per input text. On a count mismatch the partial
// result is returned TOGETHER with ErrUnexpectedResponseLength — callers
// must still treat the call as failed.
func (o *LLM) CreateEmbedding(
	ctx context.Context,
	inputTexts []string,
	model string,
	task string,
) ([][]float64, error) {
	embeddings, err := o.client.CreateEmbedding(ctx, model, task, &huggingfaceclient.EmbeddingRequest{
		Inputs: inputTexts,
		Options: map[string]any{
			"use_gpu":        false,
			"wait_for_model": true,
		},
	})
	if err != nil {
		return nil, err
	}
	if len(embeddings) == 0 {
		return nil, ErrEmptyResponse
	}
	if len(inputTexts) != len(embeddings) {
		// NOTE(review): returns partial data alongside the error; confirm
		// callers expect this rather than a nil slice.
		return embeddings, ErrUnexpectedResponseLength
	}
	return embeddings, nil
}
54 changes: 54 additions & 0 deletions llms/huggingface/internal/huggingfaceclient/embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package huggingfaceclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)

// embeddingPayload carries the inputs and extra request options for one
// embedding call. Options deliberately has no json tag: createEmbedding
// assembles the request body by hand, flattening each option into the
// top-level JSON object next to "inputs".
type embeddingPayload struct {
	Options map[string]any
	Inputs  []string `json:"inputs"`
}

// createEmbedding POSTs the payload to the Hugging Face pipeline endpoint
// for the given task and model and decodes the returned vectors.
//
// Entries in payload.Options are flattened into the top-level JSON body
// alongside "inputs".
//
// Fix: the non-OK status error was previously built with a redundant
// Sprintf followed by an Errorf over two constants; it is now a single
// fmt.Errorf call producing an equivalent message.
// nolint:lll
func (c *Client) createEmbedding(ctx context.Context, model string, task string, payload *embeddingPayload) ([][]float32, error) {
	body := map[string]any{
		"inputs": payload.Inputs,
	}
	for key, value := range payload.Options {
		body[key] = value
	}

	payloadBytes, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("marshal payload: %w", err)
	}
	url := fmt.Sprintf("%s/pipeline/%s/%s", c.url, task, model)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payloadBytes))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.Token)
	req.Header.Set("Content-Type", "application/json")

	r, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer r.Body.Close()

	if r.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unable to create embeddings: API returned unexpected status code: %d", r.StatusCode) // nolint:goerr113
	}

	var response [][]float32
	if err := json.NewDecoder(r.Body).Decode(&response); err != nil {
		return nil, fmt.Errorf("decode response: %w", err)
	}

	return response, nil
}
44 changes: 43 additions & 1 deletion llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ var (
ErrEmptyResponse = errors.New("empty response")
)

const huggingfaceAPIBaseURL = "https://api-inference.huggingface.co"

type Client struct {
Token string
Model string
Expand All @@ -24,7 +26,7 @@ func New(token string, model string) (*Client, error) {
return &Client{
Token: token,
Model: model,
url: hfInferenceAPI,
url: huggingfaceAPIBaseURL,
}, nil
}

Expand Down Expand Up @@ -73,3 +75,43 @@ func (c *Client) RunInference(ctx context.Context, request *InferenceRequest) (*
Text: text,
}, nil
}

// EmbeddingRequest is a request to create an embedding.
type EmbeddingRequest struct {
	// Options holds extra request options (e.g. "wait_for_model") that are
	// flattened into the top-level JSON body.
	Options map[string]any `json:"options"`
	// Inputs are the texts to embed.
	Inputs []string `json:"inputs"`
}

// CreateEmbedding creates embeddings for the given request and returns
// them as float64 vectors. ErrEmptyResponse is returned when the API
// yields no vectors.
func (c *Client) CreateEmbedding(
	ctx context.Context,
	model string,
	task string,
	r *EmbeddingRequest,
) ([][]float64, error) {
	payload := embeddingPayload{
		Inputs:  r.Inputs,
		Options: r.Options,
	}

	raw, err := c.createEmbedding(ctx, model, task, &payload)
	if err != nil {
		return nil, err
	}
	if len(raw) == 0 {
		return nil, ErrEmptyResponse
	}

	// The wire format is float32; widen to float64 for the public API.
	return c.convertFloat32ToFloat64(raw), nil
}

// convertFloat32ToFloat64 widens every element of the input matrix to
// float64, preserving shape.
func (c *Client) convertFloat32ToFloat64(input [][]float32) [][]float64 {
	output := make([][]float64, len(input))
	for i := range input {
		row := make([]float64, len(input[i]))
		for j := range input[i] {
			row[j] = float64(input[i][j])
		}
		output[i] = row
	}

	return output
}
4 changes: 1 addition & 3 deletions llms/huggingface/internal/huggingfaceclient/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ type (
}
)

const hfInferenceAPI = "https://api-inference.huggingface.co/models/"

func (c *Client) runInference(ctx context.Context, payload *inferencePayload) (inferenceResponsePayload, error) {
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
body := bytes.NewReader(payloadBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.url, payload.Model), body) //nolint:lll
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/models/%s", c.url, payload.Model), body) //nolint:lll
if err != nil {
return nil, err
}
Expand Down
Loading