Skip to content

Commit

Permalink
fix/AuzreOpenAI embeddings and LLM deployment name (#172)
Browse files Browse the repository at this point in the history
solves "The API deployment for this resource does not exist" for LLM and
embedding models deployed in Azure OpenAI by deployment name supported
in tmc/langchaingo#253

We can't Validate OpenAI LLM model names from hard-coded list in Azure
because the model name parameter in API request is a deployment name,
and while Microsoft advises us to use the model name as deployment name,
we did not listen, and I didn't want to coordinate redeploying with a
different name on a Friday.

This also permits use of customized models that can be deployed in Azure
side-by-side base models as added benefit so I think it was worthwhile.

Co-authored-by: Claudia Justice Kane <[email protected]>
  • Loading branch information
danielchalef and ClaudiaJ authored Aug 28, 2023
1 parent 8eb5609 commit cbf4fe4
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 31,529 deletions.
9 changes: 9 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ llm:
## OpenAI-specific settings
# Only used for Azure OpenAI API
azure_openai_endpoint:
# for Azure OpenAI API deployment, the model may be deployed with custom deployment names
# set the deployment names if you encounter in logs HTTP 404 errors:
# "The API deployment for this resource does not exist."
azure_openai:
# llm.model name is used as deployment name as reasonable default if not set
# assuming base model is deployed with deployment name matching model name
# llm_deployment: "gpt-3.5-turbo-customname"
# embeddings deployment is required when Zep is configured to use OpenAI embeddings
# embedding_deployment: "text-embedding-ada-002-customname"
# Use only with an alternate OpenAI-compatible API endpoint
openai_endpoint:
openai_org_id:
Expand Down
20 changes: 13 additions & 7 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ type StoreConfig struct {
}

type LLM struct {
Service string `mapstructure:"service"`
Model string `mapstructure:"model"`
AnthropicAPIKey string `mapstructure:"anthropic_api_key"`
OpenAIAPIKey string `mapstructure:"openai_api_key"`
AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"`
OpenAIEndpoint string `mapstructure:"openai_endpoint"`
OpenAIOrgID string `mapstructure:"openai_org_id"`
Service string `mapstructure:"service"`
Model string `mapstructure:"model"`
AnthropicAPIKey string `mapstructure:"anthropic_api_key"`
OpenAIAPIKey string `mapstructure:"openai_api_key"`
AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"`
AzureOpenAIModel AzureOpenAIConfig `mapstructure:"azure_openai"`
OpenAIEndpoint string `mapstructure:"openai_endpoint"`
OpenAIOrgID string `mapstructure:"openai_org_id"`
}

type AzureOpenAIConfig struct {
LLMDeployment string `mapstructure:"llm_deployment"`
EmbeddingDeployment string `mapstructure:"embedding_deployment"`
}

type NLP struct {
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ require (
github.com/go-playground/validator/v10 v10.15.1
github.com/golang-jwt/jwt/v5 v5.0.0
github.com/google/uuid v1.3.1
github.com/hashicorp/go-retryablehttp v0.7.4
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/oiime/logrusbun v0.1.1
Expand All @@ -31,7 +30,8 @@ require (
)

require (
github.com/tmc/langchaingo v0.0.0-20230823213549-ededff76a967
github.com/hashicorp/go-retryablehttp v0.7.4
github.com/tmc/langchaingo v0.0.0-20230827001633-72b07a1c060f
github.com/uptrace/bun/dbfixture v1.1.14
github.com/uptrace/bun/extra/bundebug v1.1.14
gopkg.in/yaml.v3 v3.0.1
Expand All @@ -48,14 +48,14 @@ require (
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/spec v0.20.9 // indirect
github.com/go-openapi/swag v0.22.4 // indirect
github.com/go-openapi/swag v0.22.3 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgx/v5 v5.4.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
Expand Down
13 changes: 6 additions & 7 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ github.com/go-openapi/spec v0.20.9 h1:xnlYNQAwKd2VQRRfwTEI0DcK+2cbuvI/0c7jx3gA8/
github.com/go-openapi/spec v0.20.9/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g=
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0=
github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
Expand Down Expand Up @@ -189,8 +188,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgx/v5 v5.4.2 h1:u1gmGDwbdRUZiwisBm/Ky2M14uQyUP65bG8+20nnyrg=
github.com/jackc/pgx/v5 v5.4.2/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
Expand Down Expand Up @@ -296,8 +295,8 @@ github.com/sv-tools/openapi v0.2.1 h1:ES1tMQMJFGibWndMagvdoo34T1Vllxr1Nlm5wz6b1a
github.com/sv-tools/openapi v0.2.1/go.mod h1:k5VuZamTw1HuiS9p2Wl5YIDWzYnHG6/FgPOSFXLAhGg=
github.com/swaggo/swag/v2 v2.0.0-rc3 h1:cIkbddJ9ftgRenDaDzyvg+2TUDLFCDffZ40yZE1r0vU=
github.com/swaggo/swag/v2 v2.0.0-rc3/go.mod h1:mfTZJmxpXWA3JQ9V381+cRlutUCo7OXd/VyIRcMhByc=
github.com/tmc/langchaingo v0.0.0-20230823213549-ededff76a967 h1:Zj1OLD2BYG8h0oQ40mlyyOGo9hNp21ebgy0HClbKK7w=
github.com/tmc/langchaingo v0.0.0-20230823213549-ededff76a967/go.mod h1:jwblKo3Lqe8r7UU9G+iQv2T/k33bLptDK+EabbC0zqk=
github.com/tmc/langchaingo v0.0.0-20230827001633-72b07a1c060f h1:nGEHKtcMfm/nci2w0Tn9vicH/xnjawBlbx4O1cLmrKw=
github.com/tmc/langchaingo v0.0.0-20230827001633-72b07a1c060f/go.mod h1:vCdA1t5qnS5YPkDsznowOziBHFn0Ul11ZqfJ2GOAi0s=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v0.3.9/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
Expand Down Expand Up @@ -350,7 +349,7 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand Down
39 changes: 39 additions & 0 deletions pkg/llms/llm_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@ var log = internal.GetLogger()
func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error) {
switch cfg.LLM.Service {
case "openai":
// Azure OpenAI model names can't be validated by any hard-coded models
// list as it is configured by custom deployment name that may or may not match the model name.
// We will copy the Model name value down to AzureOpenAI LLM Deployment
// to assume user deployed base model with matching deployment name as
// advised by Microsoft, but still support custom models or otherwise-named
// base model.
if cfg.LLM.AzureOpenAIEndpoint != "" {
if cfg.LLM.AzureOpenAIModel.LLMDeployment != "" {
cfg.LLM.Model = cfg.LLM.AzureOpenAIModel.LLMDeployment
}
if cfg.LLM.Model == "" {
return nil, fmt.Errorf(
"invalid llm deployment for %s, deployment name is required",
cfg.LLM.Service,
)
}

// EmbeddingsDeployment is only required if Zep is also configured to use
// OpenAI embeddings for document or message extractors
if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment == "" && useOpenAIEmbeddings(cfg) {
return nil, fmt.Errorf(
"invalid embeddings deployment for %s, deployment name is required",
cfg.LLM.Service,
)
}
return NewOpenAILLM(ctx, cfg)
}
if _, ok := ValidOpenAILLMs[cfg.LLM.Model]; !ok {
return nil, fmt.Errorf(
"invalid llm model \"%s\" for %s",
Expand Down Expand Up @@ -110,3 +137,15 @@ func NewRetryableHTTPClient() *retryablehttp.Client {

return retryableHttpClient
}

// useOpenAIEmbeddings is true if OpenAI embeddings are enabled
func useOpenAIEmbeddings(cfg *config.Config) bool {
switch {
case cfg.Extractors.Messages.Embeddings.Enabled:
return cfg.Extractors.Messages.Embeddings.Service == "openai"
case cfg.Extractors.Documents.Embeddings.Enabled:
return cfg.Extractors.Documents.Embeddings.Service == "openai"
}

return false
}
6 changes: 6 additions & 0 deletions pkg/llms/llm_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ func (zllm *ZepOpenAILLM) configureClient(cfg *config.Config) ([]openai.Option,
openai.WithAPIType(openai.APITypeAzure),
openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint),
)
if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "" {
options = append(
options,
openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment),
)
}
case cfg.LLM.OpenAIEndpoint != "":
// If an alternate OpenAI-compatible endpoint URL is set, use this as the base URL for requests
options = append(
Expand Down
21 changes: 21 additions & 0 deletions pkg/llms/llm_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,27 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) {
}
})

t.Run("Test with AzureOpenAIEmbeddingModel", func(t *testing.T) {
cfg := &config.Config{
LLM: config.LLM{
OpenAIAPIKey: "test-key",
AzureOpenAIEndpoint: "https://azure.openai.com",
AzureOpenAIModel: config.AzureOpenAIConfig{
EmbeddingDeployment: "test-deployment",
},
},
}

options, err := zllm.configureClient(cfg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if len(options) != 6 {
t.Errorf("Expected 6 options, got %d", len(options))
}
})

t.Run("Test with OpenAIEndpoint", func(t *testing.T) {
cfg := &config.Config{
LLM: config.LLM{
Expand Down
20 changes: 10 additions & 10 deletions pkg/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ func Create(appState *models.AppState) *http.Server {
}
}

// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
func setupRouter(appState *models.AppState) *chi.Mux {
router := chi.NewRouter()
router.Use(httpLogger.Logger("router", log))
Expand Down
Loading

0 comments on commit cbf4fe4

Please sign in to comment.