Skip to content

Commit

Permalink
prompts: FewShotPrompt implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
AMK9978 committed Jul 17, 2023
1 parent dcf7ecd commit c5d9054
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 0 deletions.
186 changes: 186 additions & 0 deletions prompts/few_shot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package prompts

import (
"errors"
"fmt"
"strings"
)

var (
// ErrNoExample is returned when none of the Examples and ExampleSelector are provided.
ErrNoExample = errors.New("no example is provided")
// ErrExamplesAndExampleSelectorProvided is returned when there are no Examples and ExampleSelector.
ErrExamplesAndExampleSelectorProvided = errors.New("only one of 'Examples' and 'example_selector' should be" +
" provided")
)

// FewShotPrompt contains fields for a few-shot prompt.
type FewShotPrompt struct {
// Examples to format into the prompt. Either this or ExamplePrompt should be provided.
Examples []map[string]string
// ExampleSelector to choose the examples to format into the prompt. Either this or Examples should be provided.
ExampleSelector ExampleSelector
// ExamplePrompt is used to format an individual example.
ExamplePrompt PromptTemplate
// A prompt template string to put before the examples.
Prefix string
// A prompt template string to put after the examples.
Suffix string
// A list of the names of the variables the prompt template expects.
InputVariables map[string]any
// Represents a map of variable names to values or functions that return values. If the value is a function, it will
// be called when the prompt template is rendered.
PartialVariables map[string]any
// String separator used to join the prefix, the examples, and suffix.
ExampleSeparator string
// The format of the prompt template. Options are: 'f-string', 'jinja2'.
TemplateFormat TemplateFormat
// Whether to try validating the template.
ValidateTemplate bool
}

// FewShotCallOptions contains optional fields for FewShotPrompt.
type FewShotCallOptions struct {
// ExampleSeparator separates between prefix, examples and suffix. Default is "\n\n"
ExampleSeparator string
// TemplateFormat is the format of the template
TemplateFormat TemplateFormat
// ValidateTemplate causes validations on prefix, suffix,templateFormat, input, and partial variables
ValidateTemplate bool
}

// NewFewShotPrompt creates a new few-shot prompt with the given input. It applies the optional input if provided. It
// returns error if there is no example, both examples and exampleSelector are provided, or CheckValidTemplate returns
// err when ValidateTemplate is true.
func NewFewShotPrompt(examplePrompt PromptTemplate, examples []map[string]string, exampleSelector ExampleSelector,
prefix string, suffix string, input map[string]interface{}, partialInput map[string]interface{},
options ...FewShotCallOptions,
) (*FewShotPrompt, error) {
err := validateExamples(examples, exampleSelector)
if err != nil {
return nil, err
}
prompt := &FewShotPrompt{
ExamplePrompt: examplePrompt,
Prefix: prefix,
Suffix: suffix,
InputVariables: input,
PartialVariables: partialInput,
Examples: examples,
ExampleSelector: exampleSelector,
ExampleSeparator: "\n\n",
TemplateFormat: TemplateFormatGoTemplate,
ValidateTemplate: true,
}

if len(options) > 0 {
option := options[0]
if option.ExampleSeparator != "" {
prompt.ExampleSeparator = option.ExampleSeparator
}
if option.TemplateFormat != "" {
prompt.TemplateFormat = option.TemplateFormat
}
if !option.ValidateTemplate {
prompt.ValidateTemplate = false
}
}

if prompt.ValidateTemplate {
err := CheckValidTemplate(prompt.Prefix+prompt.Suffix, prompt.TemplateFormat, append(getMapKeys(input),
getMapKeys(partialInput)...))
if err != nil {
return nil, err
}
}
return prompt, nil
}

// validateExamples validates the provided example and exampleSelector. One of them must be provided only.
func validateExamples(examples []map[string]string, exampleSelector ExampleSelector) error {
if examples != nil && exampleSelector != nil {
return ErrExamplesAndExampleSelectorProvided
} else if examples == nil && exampleSelector == nil {
return ErrNoExample
}
return nil
}

// getExamples returns the provided examples or returns error when there is no example.
func (p *FewShotPrompt) getExamples(input map[string]string) ([]map[string]string, error) {
switch {
case p.Examples != nil:
return p.Examples, nil
case p.ExampleSelector != nil:
return p.ExampleSelector.SelectExamples(input), nil
default:
return nil, ErrNoExample
}
}

// Format assembles and formats the pieces of the prompt with the given input values and partial values.
func (p *FewShotPrompt) Format(values map[string]interface{}) (string, error) {
resolvedValues, err := resolvePartialValues(p.PartialVariables, values)
if err != nil {
return "", err
}
stringResolvedValues := map[string]string{}
for k, v := range resolvedValues {
strVal, ok := v.(string)
if !ok {
return "", fmt.Errorf("%w: %T", ErrInvalidPartialVariableType, v)
}
stringResolvedValues[k] = strVal
}
examples, err := p.getExamples(stringResolvedValues)
if err != nil {
return "", err
}
exampleStrings := make([]string, len(examples))

for i, example := range examples {
exampleMap := make(map[string]interface{})
for k, v := range example {
exampleMap[k] = v
}

res, err := p.ExamplePrompt.Format(exampleMap)
if err != nil {
return "", err
}
exampleStrings[i] = res
}

template := assemblePieces(p.Prefix, p.Suffix, exampleStrings, p.ExampleSeparator)
return defaultformatterMapping[p.TemplateFormat](template, resolvedValues)
}

// assemblePieces assembles the pieces of the few-shot prompt.
func assemblePieces(prefix, suffix string, exampleStrings []string, separator string) string {
const additionalCapacity = 2
pieces := make([]string, 0, len(exampleStrings)+additionalCapacity)
if prefix != "" {
pieces = append(pieces, prefix)
}

for _, elem := range exampleStrings {
if elem != "" {
pieces = append(pieces, elem)
}
}

if suffix != "" {
pieces = append(pieces, suffix)
}

return strings.Join(pieces, separator)
}

// getMapKeys returns the keys of the provided map.
func getMapKeys(inputMap map[string]any) []string {
keys := make([]string, 0, len(inputMap))
for k := range inputMap {
keys = append(keys, k)
}
return keys
}
164 changes: 164 additions & 0 deletions prompts/few_shot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package prompts

import (
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFewShotPrompt_Format(t *testing.T) {
examplePrompt := NewPromptTemplate("{{.question}}: {{.answer}}", []string{"question", "answer"})
t.Parallel()
testCases := []struct {
name string
examplePrompt PromptTemplate
examples []map[string]string
exampleSelector ExampleSelector
prefix string
suffix string
input map[string]interface{}
partialInput map[string]interface{}
options FewShotCallOptions
wantErr bool
expected string
}{
{
"Prefix only",
examplePrompt,
[]map[string]string{},
nil,
"This is a {{.foo}} test.",
"",
map[string]interface{}{"foo": "bar"},
nil,
FewShotCallOptions{},
false,
"This is a bar test.",
},
{
"Suffix only",
examplePrompt,
[]map[string]string{},
nil,
"",
"This is a {{.foo}} test.",
map[string]interface{}{"foo": "bar"},
nil,
FewShotCallOptions{},
false,
"This is a bar test.",
},
{
"insufficient InputVariables w err",
examplePrompt,
[]map[string]string{},
nil,
"",
"This is a {{.foo}} test.",
map[string]interface{}{"bar": "bar"},
nil,
FewShotCallOptions{},
true,
"This is a bar test.",
},
{
"InputVariables neither Examples nor ExampleSelector w err",
examplePrompt,
nil,
nil,
"",
"",
map[string]interface{}{"bar": "bar"},
nil,
FewShotCallOptions{},
true,
ErrNoExample.Error(),
},
{
"functionality test",
examplePrompt,
[]map[string]string{
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
},
nil,
"This is a test about {{.content}}.",
"Now you try to talk about {{.new_content}}.",
map[string]interface{}{"content": "animals", "new_content": "party"},
nil,
FewShotCallOptions{
ExampleSeparator: "\n",
},
false,
"This is a test about animals.\nfoo: bar\nbaz: foo\nNow you try to talk about party.",
},
{
"functionality test",
examplePrompt,
[]map[string]string{
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
},
nil,
"This is a test about {{.content}}.",
"Now you try to talk about {{.new_content}}.",
map[string]interface{}{"content": "animals"},
map[string]interface{}{"new_content": func() string { return "party" }},
FewShotCallOptions{
ExampleSeparator: "\n",
ValidateTemplate: true,
},
false,
"This is a test about animals.\nfoo: bar\nbaz: foo\nNow you try to talk about party.",
},
{
"invalid template w err",
examplePrompt,
[]map[string]string{
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
},
nil,
"This is a test about {{.wrong_content}}.",
"Now you try to talk about {{.new_content}}.",
map[string]interface{}{"content": "animals"},
map[string]interface{}{"new_content": func() string { return "party" }},
FewShotCallOptions{
ExampleSeparator: "\n",
ValidateTemplate: true,
},
true,
"template: template:1:23: executing \"template\" at <.wrong_content>: map has no entry for key " +
"\"wrong_content\"",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p, err := NewFewShotPrompt(tc.examplePrompt, tc.examples, tc.exampleSelector, tc.prefix, tc.suffix, tc.input,
tc.partialInput, tc.options)
if tc.wantErr && err != nil {
if err.Error() == tc.expected {
return
}
t.Errorf("FewShotPrompt.Format() error = %v, wantErr %v", err, tc.wantErr)
return
}
fp, err := p.Format(tc.input)
if (err != nil) != tc.wantErr {
t.Errorf("FewShotPrompt.Format() error = %v, wantErr %v", err, tc.wantErr)
return
}
if tc.wantErr {
return
}
got := fmt.Sprint(fp)
if cmp.Diff(tc.expected, got) != "" {
t.Errorf("unexpected prompt output (-want +got):\n%s", cmp.Diff(tc.expected, got))
}
})
}
}

0 comments on commit c5d9054

Please sign in to comment.