Skip to content

Commit

Permalink
Update plugin cfg, configure deps, fix schema (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
mszostok authored Feb 26, 2024
1 parent 849af53 commit c114f42
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 260 deletions.
179 changes: 66 additions & 113 deletions internal/source/ai-brain/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/kubeshop/botkube/pkg/api/source"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/wait"
)

const openAIPollInterval = 2 * time.Second
Expand All @@ -21,9 +22,36 @@ type Payload struct {
MessageID string `json:"messageId"`
}

func (i *sourceInstance) handle(in source.ExternalRequestInput) (api.Message, error) {
p := new(Payload)
err := json.Unmarshal(in.Payload, p)
type assistant struct {
log logrus.FieldLogger
out chan<- source.Event
openaiClient *openai.Client
assistID string
tools map[string]tool
threadMapping map[string]string
}

func newAssistant(cfg *Config, log logrus.FieldLogger, out chan source.Event, kubeConfigPath string) *assistant {
kcRunner := NewKubectlRunner(kubeConfigPath)

return &assistant{
log: log,
out: out,
openaiClient: openai.NewClient(cfg.OpenAIAPIKey),
assistID: cfg.OpenAIAssistantID,
threadMapping: make(map[string]string),
tools: map[string]tool{
"kubectlGetPods": kcRunner.GetPods,
"kubectlGetSecrets": kcRunner.GetSecrets,
"kubectlDescribePod": kcRunner.DescribePod,
"kubectlLogs": kcRunner.Logs,
},
}
}

func (i *assistant) handle(in source.ExternalRequestInput) (api.Message, error) {
var p Payload
err := json.Unmarshal(in.Payload, &p)
if err != nil {
return api.Message{}, fmt.Errorf("while unmarshalling payload: %w", err)
}
Expand All @@ -33,7 +61,11 @@ func (i *sourceInstance) handle(in source.ExternalRequestInput) (api.Message, er
}

go func() {
if err := i.handleThread(context.Background(), p); err != nil {
if err := i.handleThread(context.Background(), &p); err != nil {
// TODO: It would be great to send the user prompt and error message
// back to us for analysis and potential fixing, enhancing our prompt.
// can we do that @Blair?
i.out <- source.Event{Message: msgUnableToHelp(p.MessageID)}
i.log.WithError(err).Error("failed to handle request")
}
}()
Expand All @@ -42,8 +74,7 @@ func (i *sourceInstance) handle(in source.ExternalRequestInput) (api.Message, er
}

// handleThread creates a new OpenAI assistant thread and handles the conversation.
func (i *sourceInstance) handleThread(ctx context.Context, p *Payload) error {
// Start a new thread and run it.
func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
run, err := i.openaiClient.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{
RunRequest: openai.RunRequest{
AssistantID: i.assistID,
Expand All @@ -64,15 +95,12 @@ func (i *sourceInstance) handleThread(ctx context.Context, p *Payload) error {
return fmt.Errorf("while creating thread and run: %w", err)
}

for {
// Wait a little bit before polling. OpenAI assistant api does not support streaming yet.
time.Sleep(openAIPollInterval)
i.threadMapping[p.MessageID] = run.ID

// Get the run.
return wait.PollUntilContextCancel(ctx, openAIPollInterval, false, func(ctx context.Context) (bool, error) {
run, err = i.openaiClient.RetrieveRun(ctx, run.ThreadID, run.ID)
if err != nil {
i.out <- source.Event{Message: msgUnableToHelp(p.MessageID)}
return fmt.Errorf("while retrieving assistant thread run: %w", err)
return false, fmt.Errorf("while retrieving assistant thread run: %w", err)
}

i.log.WithFields(logrus.Fields{
Expand All @@ -82,31 +110,27 @@ func (i *sourceInstance) handleThread(ctx context.Context, p *Payload) error {

switch run.Status {
case openai.RunStatusCancelling, openai.RunStatusFailed, openai.RunStatusExpired:
i.out <- source.Event{Message: msgUnableToHelp(p.MessageID)}
return nil
return false, fmt.Errorf("got unexpected status: %s", run.Status)

case openai.RunStatusQueued, openai.RunStatusInProgress:
continue
return false, nil // continue

// Fetch and return the response.
case openai.RunStatusCompleted:
if err = i.handleStatusCompleted(ctx, run, p); err != nil {
i.out <- source.Event{Message: msgUnableToHelp(p.MessageID)}
return fmt.Errorf("while handling completed case: %w", err)
return false, fmt.Errorf("while handling completed case: %w", err)
}
return nil
return true, nil // success

// The assistant is attempting to call a function.
case openai.RunStatusRequiresAction:
if err = i.handleStatusRequiresAction(ctx, run); err != nil {
i.out <- source.Event{Message: msgUnableToHelp(p.MessageID)}
return fmt.Errorf("while handling requires action: %w", err)
return false, fmt.Errorf("while handling requires action: %w", err)
}
}
}
return false, nil
})
}

func (i *sourceInstance) handleStatusCompleted(ctx context.Context, run openai.Run, p *Payload) error {
func (i *assistant) handleStatusCompleted(ctx context.Context, run openai.Run, p *Payload) error {
limit := 1
msgList, err := i.openaiClient.ListMessage(ctx, run.ThreadID, &limit, nil, nil, nil)
if err != nil {
Expand All @@ -119,16 +143,7 @@ func (i *sourceInstance) handleStatusCompleted(ctx context.Context, run openai.R
if len(msgList.Messages) == 0 {
i.log.Debug("no response messages were found, that seems like an edge case.")
i.out <- source.Event{
Message: api.Message{
ParentActivityID: p.MessageID,
Sections: []api.Section{
{
Base: api.Base{
Body: api.Body{Plaintext: "I am sorry, but I don't have a good answer."},
},
},
},
},
Message: msgNoAIAnswer(p.MessageID),
}
return nil
}
Expand All @@ -141,103 +156,41 @@ func (i *sourceInstance) handleStatusCompleted(ctx context.Context, run openai.R
}

i.out <- source.Event{
Message: api.Message{
ParentActivityID: p.MessageID,
Sections: []api.Section{
{
Base: api.Base{
Body: api.Body{Plaintext: c.Text.Value},
},
Context: []api.ContextItem{
{Text: "AI-generated content may be incorrect."},
},
},
},
},
Message: msgAIAnswer(p.MessageID, c.Text.Value),
}
}

return nil
}

func (i *sourceInstance) handleStatusRequiresAction(ctx context.Context, run openai.Run) error {
type tool func(ctx context.Context, args []byte) (string, error)

func (i *assistant) handleStatusRequiresAction(ctx context.Context, run openai.Run) error {
// That should never happen, unless there is a bug or something is wrong with OpenAI APIs.
if run.RequiredAction == nil || run.RequiredAction.SubmitToolOutputs == nil {
return errors.New("run.RequiredAction or run.RequiredAction.SubmitToolOutputs is nil, that should not happen")
}

toolOutputs := []openai.ToolOutput{}

var toolOutputs []openai.ToolOutput
for _, t := range run.RequiredAction.SubmitToolOutputs.ToolCalls {
if t.Type != openai.ToolTypeFunction {
continue
}

switch t.Function.Name {
case "kubectlGetPods":
args := &kubectlGetPodsArgs{}
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil {
return err
}

out, err := kubectlGetPods(args)
if err != nil {
return err
}

toolOutputs = append(toolOutputs, openai.ToolOutput{
ToolCallID: t.ID,
Output: string(out),
})

case "kubectlGetSecrets":
args := &kubectlGetSecretsArgs{}
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil {
return err
}

out, err := kubectlGetSecrets(args)
if err != nil {
return err
}

toolOutputs = append(toolOutputs, openai.ToolOutput{
ToolCallID: t.ID,
Output: string(out),
})

case "kubectlDescribePod":
args := &kubectlDescribePodArgs{}
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil {
return err
}

out, err := kubectlDescribePod(args)
if err != nil {
return err
}

toolOutputs = append(toolOutputs, openai.ToolOutput{
ToolCallID: t.ID,
Output: string(out),
})

case "kubectlLogs":
args := &kubectlLogsArgs{}
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil {
return err
}

out, err := kubectlLogs(args)
if err != nil {
return err
}
doer, found := i.tools[t.Function.Name]
if !found {
continue
}

toolOutputs = append(toolOutputs, openai.ToolOutput{
ToolCallID: t.ID,
Output: string(out),
})
out, err := doer(ctx, []byte(t.Function.Arguments))
if err != nil {
return err
}

toolOutputs = append(toolOutputs, openai.ToolOutput{
ToolCallID: t.ID,
Output: out,
})
}

_, err := i.openaiClient.SubmitToolOutputs(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{
Expand Down
Loading

0 comments on commit c114f42

Please sign in to comment.