diff --git a/embeddings/ernie/ernie.go b/embeddings/ernie/ernie.go index b2af095ad..eba18433d 100644 --- a/embeddings/ernie/ernie.go +++ b/embeddings/ernie/ernie.go @@ -9,28 +9,41 @@ import ( // Ernie https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu type Ernie struct { - client *ernie.LLM + client *ernie.LLM + batchSize int + stripNewLines bool } var _ embeddings.Embedder = &Ernie{} -// todo: use option pass, more: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu#body%E5%8F%82%E6%95%B0 -const batchSize = 16 +// NewErnie creates a new Ernie embedder. Options can supply the client, the strip-new-lines flag and the batch size; unset options fall back to the package defaults, and a client is created with ernie.New when none is provided. +func NewErnie(opts ...Option) (*Ernie, error) { + v := &Ernie{ + stripNewLines: defaultStripNewLines, + batchSize: defaultBatchSize, + } -func NewErnie() (*Ernie, error) { - llm, e := ernie.New() - if e != nil { - return nil, e + for _, opt := range opts { + opt(v) } - return &Ernie{client: llm}, nil + + if v.client == nil { + client, err := ernie.New() + if err != nil { + return nil, err + } + v.client = client + } + + return v, nil } // EmbedDocuments implements embeddings.Embedder . // simple impl. func (e *Ernie) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) { batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, true), - batchSize, + embeddings.MaybeRemoveNewLines(texts, e.stripNewLines), + e.batchSize, ) emb := make([][]float64, 0, len(texts)) diff --git a/embeddings/ernie/options.go b/embeddings/ernie/options.go new file mode 100644 index 000000000..397804156 --- /dev/null +++ b/embeddings/ernie/options.go @@ -0,0 +1,33 @@ +package ernie + +import "github.com/tmc/langchaingo/llms/ernie" + +const ( + // see: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu#body%E5%8F%82%E6%95%B0 + defaultBatchSize = 16 + defaultStripNewLines = true +) + +// Option is a function type that can be used to modify the Ernie embedder. 
+type Option func(p *Ernie) + +// WithClient is an option for providing the LLM client. The client is received by value, so the embedder stores a pointer to its own copy. +func WithClient(client ernie.LLM) Option { + return func(e *Ernie) { + e.client = &client + } +} + +// WithBatchSize is an option for specifying the batch size. The value is stored as given, without validation. +func WithBatchSize(batchSize int) Option { + return func(e *Ernie) { + e.batchSize = batchSize + } +} + +// WithStripNewLines is an option for specifying whether new lines should be stripped. +func WithStripNewLines(stripNewLines bool) Option { + return func(e *Ernie) { + e.stripNewLines = stripNewLines + } +}