forked from hegelai/prompttools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
50 lines (41 loc) · 1.42 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
from typing import Dict, List, Tuple

import prompttools.testing.prompttest as prompttest
from prompttools.utils import similarity
# Maps each question (the "input" value passed in ``user_input`` below) to the
# answer the model is expected to give; looked up by ``measure_similarity``.
EXPECTED = {"Who was the first president of the USA?": "George Washington"}

# Fail fast at import time when neither credentials nor DEBUG are configured.
# NOTE(review): only the *presence* of DEBUG is checked, so DEBUG=0 also passes
# even though the message asks for DEBUG=1 — confirm that is intended.
if not (("OPENAI_API_KEY" in os.environ) or ("DEBUG" in os.environ)):
    print(
        "Error: This example requires you to set either your OPENAI_API_KEY or DEBUG=1",
        file=sys.stderr,
    )
    # sys.exit rather than the site-injected exit(), which is intended for the
    # interactive shell and is absent when Python runs with -S.
    sys.exit(1)
def extract_responses(output) -> List[str]:
    r"""
    Helper function to unwrap the OpenAI response object.

    Collects the ``"text"`` field of every entry in ``output["choices"]``.

    Args:
        output: Response indexable as a mapping with a ``"choices"`` list whose
            items expose ``"text"``; presumably the raw OpenAI completion
            payload handed over by prompttest — confirm against the harness.

    Returns:
        The response texts, in the order they appear in ``choices``.
    """
    return [choice["text"] for choice in output["choices"]]
# NOTE(review): OpenAI has retired text-davinci-003; gpt-3.5-turbo-instruct is
# its completions-API replacement — confirm prompttest accepts it before switching.
@prompttest.prompttest(
    model_name="text-davinci-003",
    metric_name="similar_to_expected",
    prompt_template="Answer the following question: {{input}}",
    user_input=[{"input": "Who was the first president of the USA?"}],
)
def measure_similarity(
    input_pair: Tuple[str, Dict[str, str]], results: Dict, metadata: Dict
) -> float:
    r"""
    Score the model's responses against the expected answer for the question.

    Looks up the expected answer for the question in ``EXPECTED`` and compares
    it with every text response in ``results`` using ``similarity.compute``,
    returning the minimum score over all responses.

    Args:
        input_pair: Only the second element is read here; its ``"input"``
            entry is the question, which must be a key of ``EXPECTED``.
        results: Raw model output, unwrapped with ``extract_responses``.
        metadata: Unused here; presumably required by the prompttest
            calling convention.

    Returns:
        The lowest ``similarity.compute`` score across all responses.

    Raises:
        KeyError: if the question has no entry in ``EXPECTED``.
        ValueError: if ``results`` contains no responses to score.
    """
    expected = EXPECTED[input_pair[1]["input"]]
    scores = [
        similarity.compute(expected, response)
        for response in extract_responses(results)
    ]
    if not scores:
        # min() would fail here with an opaque "empty sequence" message.
        raise ValueError("measure_similarity: model returned no responses to score")
    # Taking the minimum makes the metric a worst case over all responses
    # (assuming higher similarity.compute scores mean "more similar" — confirm).
    return min(scores)
if __name__ == "__main__":
    # Script entry point: hand control to prompttest's runner.
    # Presumably it executes the @prompttest.prompttest-decorated tests in this
    # module — confirm against prompttools.testing.prompttest.
    prompttest.main()