Skip to content

Commit

Permalink
feat: Alternative OpenAI Endpoint URL (#126)
Browse files Browse the repository at this point in the history
* feat: Use a custom URL for OpenAI API endpoints

* refactor OpenAI config

* minimal tests

---------

Co-authored-by: Drew Baumann <[email protected]>
  • Loading branch information
danielchalef and drewbaumann authored Jul 8, 2023
1 parent acc0312 commit 2f55d2c
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 12 deletions.
3 changes: 3 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
llm:
# gpt-3.5-turbo or gpt-4
model: "gpt-3.5-turbo"
# Only used for Azure OpenAI API
azure_openai_endpoint:
# Use only with an alternate OpenAI-compatible API endpoint
openai_endpoint:
openai_org_id:
nlp:
server_url: "http://localhost:8080"
Expand Down
1 change: 1 addition & 0 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type LLM struct {
// OpenAIAPIKey is loaded from ENV not config file.
OpenAIAPIKey string `mapstructure:"openai_api_key"`
AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"`
OpenAIEndpoint string `mapstructure:"openai_endpoint"`
OpenAIOrgID string `mapstructure:"openai_org_id"`
}

Expand Down
5 changes: 5 additions & 0 deletions http-client.private.env.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"dev": {
"host": "http://localhost:8000"
}
}
10 changes: 4 additions & 6 deletions pkg/llms/llm_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ func NewLLMError(message string, originalError error) *LLMError {
}

// MaxLLMTokensMap maps each supported LLM model name to the maximum
// number of tokens its context window can hold. Used to bound prompt
// sizes before sending requests.
var MaxLLMTokensMap = map[string]int{
	"gpt-3.5-turbo": 4096,
	"gpt-4":         8192,
}

// ValidLLMMap is the set of LLM model names accepted in configuration.
// A model is valid if and only if it is present with a true value.
var ValidLLMMap = map[string]bool{
	"gpt-3.5-turbo": true,
	"gpt-4":         true,
}
28 changes: 22 additions & 6 deletions pkg/llms/llm_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,38 @@ var (
)

func NewOpenAIRetryClient(cfg *config.Config) *openairetryclient.OpenAIRetryClient {
// Retrieve the OpenAIAPIKey from configuration
apiKey := cfg.LLM.OpenAIAPIKey
// If the key is not set, log a fatal error and exit
if apiKey == "" {
log.Fatal(OpenAIAPIKeyNotSetError)
}
if cfg.LLM.AzureOpenAIEndpoint != "" && cfg.LLM.OpenAIEndpoint != "" {
log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set")
}

// Initiate the openAIClientConfig with the default configuration
openAIClientConfig := openai.DefaultConfig(apiKey)

var openAIClientConfig openai.ClientConfig
azureEndpoint := cfg.LLM.AzureOpenAIEndpoint
if azureEndpoint != "" {
openAIClientConfig = openai.DefaultAzureConfig(apiKey, azureEndpoint)
} else {
openAIClientConfig = openai.DefaultConfig(apiKey)
switch {
case cfg.LLM.AzureOpenAIEndpoint != "":
// Check configuration for AzureOpenAIEndpoint; if it's set, use the DefaultAzureConfig
// and provided endpoint URL
openAIClientConfig = openai.DefaultAzureConfig(apiKey, cfg.LLM.AzureOpenAIEndpoint)
case cfg.LLM.OpenAIEndpoint != "":
// If an alternate OpenAI-compatible endpoint URL is set, use this as the base URL for requests
openAIClientConfig.BaseURL = cfg.LLM.OpenAIEndpoint
default:
// If no specific endpoints are defined, use the default configuration with the OpenAIOrgID
// This is optional and may just be an empty string
openAIClientConfig.OrgID = cfg.LLM.OpenAIOrgID
}

// Create a new client instance with the final openAIClientConfig
client := openai.NewClientWithConfig(openAIClientConfig)

// Return a new retry client. This client contains a pre-configured OpenAI client
// and additional retry logic (timeout duration and maximum number of attempts)
return &openairetryclient.OpenAIRetryClient{
Client: *client,
Config: struct {
Expand Down
57 changes: 57 additions & 0 deletions pkg/llms/llm_openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package llms

import (
"testing"
"time"

"github.com/getzep/zep/config"
"github.com/getzep/zep/pkg/llms/openairetryclient"
"github.com/stretchr/testify/assert"
)

// Minimal set of test cases. We'd need to refactor the error states to not immediately
// exit the program to test more thoroughly.

// TestNewOpenAIRetryClient_ValidAzureConfig verifies that a retry client is
// constructed with the expected retry settings when only the Azure OpenAI
// endpoint is configured.
func TestNewOpenAIRetryClient_ValidAzureConfig(t *testing.T) {
	llmCfg := config.LLM{
		OpenAIAPIKey:        "testKey",
		AzureOpenAIEndpoint: "azureEndpoint",
	}
	cfg := &config.Config{LLM: llmCfg}

	retryClient := NewOpenAIRetryClient(cfg)

	assert.IsType(t, &openairetryclient.OpenAIRetryClient{}, retryClient)
	assert.IsType(t, time.Duration(0), retryClient.Config.Timeout)
	assert.Equal(t, uint(5), retryClient.Config.MaxAttempts)
}

// TestNewOpenAIRetryClient_ValidConfig verifies that a retry client is
// constructed with the expected retry settings from a minimal configuration
// containing only the API key.
func TestNewOpenAIRetryClient_ValidConfig(t *testing.T) {
	cfg := &config.Config{
		LLM: config.LLM{OpenAIAPIKey: "testKey"},
	}

	retryClient := NewOpenAIRetryClient(cfg)

	assert.IsType(t, &openairetryclient.OpenAIRetryClient{}, retryClient)
	assert.IsType(t, time.Duration(0), retryClient.Config.Timeout)
	assert.Equal(t, uint(5), retryClient.Config.MaxAttempts)
}

// TestNewOpenAIRetryClient_ValidConfigCustomEndpoint verifies that a retry
// client is constructed with the expected retry settings when an alternate
// OpenAI-compatible endpoint URL is configured.
func TestNewOpenAIRetryClient_ValidConfigCustomEndpoint(t *testing.T) {
	llmCfg := config.LLM{
		OpenAIAPIKey:   "testKey",
		OpenAIEndpoint: "https://api.openai.com/v1",
	}
	cfg := &config.Config{LLM: llmCfg}

	retryClient := NewOpenAIRetryClient(cfg)

	assert.IsType(t, &openairetryclient.OpenAIRetryClient{}, retryClient)
	assert.IsType(t, time.Duration(0), retryClient.Config.Timeout)
	assert.Equal(t, uint(5), retryClient.Config.MaxAttempts)
}

0 comments on commit 2f55d2c

Please sign in to comment.