forked from ChessScholar/Auto-PaLM
memory.py
# memory.py
import os
import json
import re
import time
import numpy as np
import google.generativeai as palm


class Memory:
    def __init__(self):
        self.prompts = set()
        self.questions = {}
        self.answers = {}
        self.context = []

    def get_embedding(self, text):
        # Pick the first available PaLM model that supports the embedText method
        # and return the text together with its embedding vector.
        embedding_models = [m for m in palm.list_models() if 'embedText' in m.supported_generation_methods]
        embedding_model = embedding_models[0].name
        embedding = palm.generate_embeddings(model=embedding_model, text=text)['embedding']
        return (text, embedding)

    def update_prompt(self, t):
        # Record the prompt and append its embedding to the context,
        # stored with an empty text label.
        self.prompts.add(t)
        _, t_embedding = self.get_embedding(t)
        self.context.append(("", t_embedding))

    def update_question(self, q, a):
        # Record the question/answer pair and append the answer's embedding
        # to the context, keyed by the question text.
        self.questions[q] = a
        self.answers[q] = a
        _, a_embedding = self.get_embedding(a)
        self.context.append((q, a_embedding))

    def filter_context(self, prompt_embedding, threshold=0.1):
        # Keep only context entries whose embedding has a dot product with the
        # prompt embedding at or above the threshold.
        self.context = list(filter(lambda item: np.dot(item[1], prompt_embedding) >= threshold, self.context))
        return self.context

    def cosine_similarity(self, a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
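

# Hedged usage sketch (not part of the original file): it assumes palm.configure()
# has already been called with a valid API key and that at least one model
# supporting 'embedText' is available; the prompt and Q/A strings below are
# purely illustrative placeholders.
if __name__ == "__main__":
    memory = Memory()

    # Store a prompt and a question/answer pair; both are embedded on insert.
    memory.update_prompt("You are a helpful chess tutor.")
    memory.update_question("What is the Sicilian Defence?",
                           "A reply to 1.e4 beginning 1...c5.")

    # Embed a new prompt and drop context entries whose dot product with it
    # falls below the default threshold of 0.1.
    _, new_prompt_embedding = memory.get_embedding("Explain a common response to 1.e4.")
    relevant = memory.filter_context(new_prompt_embedding)
    print(f"{len(relevant)} context item(s) retained.")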