embeddings.py
import os         # only used by the commented-out Chroma pipeline below
import mimetypes  # only used by the commented-out Chroma pipeline below

# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain_community.vectorstores import Chroma
# from langchain.document_loaders import PyPDFLoader, TextLoader
# from google.cloud import storage
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_community.vectorstores import BigQueryVectorSearch
from langchain_community.vectorstores.utils import DistanceStrategy
# def upload_to_gcs(destination_blob_name, file_path):
#     bucket_name = 'durable-return-430917-rag'
#     client = storage.Client()
#     bucket = client.bucket(bucket_name)
#     blob = bucket.blob(destination_blob_name)
#     blob.upload_from_filename(file_path)
# def load_and_split_document(file_path):
#     # Determine the MIME type of the file
#     mime_type, _ = mimetypes.guess_type(file_path)
#     # Select the appropriate loader based on the file type
#     if mime_type == 'application/pdf':
#         loader = PyPDFLoader(file_path)
#     elif mime_type == 'text/plain':
#         loader = TextLoader(file_path)
#     else:
#         raise ValueError(f"Unsupported file type: {mime_type}")
#     # Load the document
#     document = loader.load()
#     # Split the document into chunks
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=300)
#     chunks = text_splitter.split_documents(document)
#     original_file_name = os.path.basename(file_path)
#     upload_to_gcs(original_file_name, file_path)
#     save_to_chroma(chunks)
# def save_to_chroma(chunks):
#     embeddings = VertexAIEmbeddings(
#         model_name="textembedding-gecko",
#         batch_size=1,
#         requests_per_minute=60
#     )
#     db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
#     batch_size = 10  # Adjust as needed
#     for i in range(0, len(chunks), batch_size):
#         batch = chunks[i:i + batch_size]
#         texts = []
#         metadatas = []
#         ids = []
#         curr_sum = 0
#         lastidx = ''
#         for chunk in batch:
#             if chunk.page_content.strip():
#                 texts.append(chunk.page_content)
#                 metadatas.append(chunk.metadata)
#                 chunksrc = chunk.metadata.get("source").split('\\')[-1]
#                 if chunksrc == lastidx:
#                     curr_sum += 1
#                     id = chunksrc + str(curr_sum)
#                 else:
#                     curr_sum = 1
#                     id = chunksrc + str(curr_sum)
#                 lastidx = chunksrc
#                 ids.append(id)
#         if texts:
#             try:
#                 db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
#             except Exception as e:
#                 print(f"Error adding batch {i//batch_size + 1} to Chroma: {e}")
#     db.persist()
#     return db
# def query_rag(query_text: str):
#     # Prepare the DB.
#     embedding_function = VertexAIEmbeddings(
#         model_name="textembedding-gecko",
#         batch_size=1,
#         requests_per_minute=60
#     )
#     db = Chroma(persist_directory='chroma_db', embedding_function=embedding_function)
#     # print(len(db.get()['ids']))
#     # Search the DB.
#     results = db.similarity_search_with_score(query_text, k=1)
#     documents = []
#     for doc, _score in results:
#         documents.append(doc.metadata["source"].split('\\')[-1])
#     print(documents)
#     docs = []
#     for document in documents:
#         docs.append(document + "1")
#         docs.append(document + "2")
#     doc_result = db.get(ids=docs)
#     query_docs = doc_result["documents"]
#     context_text = "\n\n---\n\n".join([doc for doc in query_docs])
#     return context_text
# load_and_split_document('requirements.txt')
def query_bq(query_text: str):
    """Embed the query with Vertex AI and return the closest matching chunk
    from the BigQuery vector table, or None if nothing is found."""
    PROJECT_ID = 'durable-return-430917-b5'
    embeddings = VertexAIEmbeddings(
        model_name="textembedding-gecko",
        batch_size=1,
        requests_per_minute=60
    )
    store = BigQueryVectorSearch(
        project_id=PROJECT_ID,
        dataset_name='vector_db',
        table_name='test',
        embedding=embeddings,
        distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
    )
    # Retrieve only the single nearest neighbor for the query.
    docs = store.similarity_search(query_text, k=1)
    if docs:
        return docs[0].page_content
    return None
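

# A minimal usage sketch (not part of the original module): it assumes the
# BigQuery dataset 'vector_db' and table 'test' already contain embedded
# document chunks, and that GCP credentials are configured in the
# environment (e.g. via GOOGLE_APPLICATION_CREDENTIALS). The query string
# below is a hypothetical example.
if __name__ == "__main__":
    context = query_bq("What are the project requirements?")
    if context:
        print(context)
    else:
        print("No matching document found.")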