Skip to content

Commit

Permalink
Merge pull request #395 from zyclove/main
Browse files Browse the repository at this point in the history
feat: add OpenSearch vector database to be supported
  • Loading branch information
zainhoda authored Apr 30, 2024
2 parents 0703021 + 86d0221 commit 4c691e5
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ zhipuai = ["zhipuai"]
ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
opensearch = ["opensearch-py", "opensearch-dsl"]
1 change: 1 addition & 0 deletions src/vanna/opensearch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .opensearch_vector import OpenSearch_VectorStore
289 changes: 289 additions & 0 deletions src/vanna/opensearch/opensearch_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import base64
import uuid
from typing import List

import pandas as pd
from opensearchpy import OpenSearch

from ..base import VannaBase


class OpenSearch_VectorStore(VannaBase):
    """Vanna vector store backed by OpenSearch full-text indices.

    Training data is stored in three separate indices: documentation
    snippets, DDL statements, and question/SQL pairs. Retrieval uses
    plain ``match`` queries (lexical search), so no embedding model is
    needed — ``generate_embedding`` is intentionally a no-op.
    """

    def __init__(self, config=None):
        """Connect to an OpenSearch cluster.

        Recognized ``config`` keys (all optional):
            es_document_index / es_ddl_index / es_question_sql_index:
                index names (defaults: vanna_document_index,
                vanna_ddl_index, vanna_questions_sql_index).
            es_urls: full connection URL(s); takes precedence over
                es_host/es_port when present.
            es_host, es_port: host/port connection (localhost:9200).
            es_ssl, es_verify_certs: TLS options (both default False).
            es_user, es_password: basic-auth credentials.
            es_encoded_base64: if truthy (and user/password are set),
                send credentials as a pre-encoded ``Authorization``
                header instead of ``http_auth``.
            es_headers: custom headers; overrides any generated
                auth header.
            es_timeout: request timeout in seconds (default 60).
            es_max_retries: retry count (default 10).
        """
        VannaBase.__init__(self, config=config)
        cfg = config if config is not None else {}

        document_index = cfg.get("es_document_index", "vanna_document_index")
        ddl_index = cfg.get("es_ddl_index", "vanna_ddl_index")
        question_sql_index = cfg.get(
            "es_question_sql_index", "vanna_questions_sql_index"
        )

        self.document_index = document_index
        self.ddl_index = ddl_index
        self.question_sql_index = question_sql_index
        print("OpenSearch_VectorStore initialized with document_index: ",
              document_index, " ddl_index: ", ddl_index,
              " question_sql_index: ", question_sql_index)

        es_urls = cfg.get("es_urls")

        # Host/port connection parameters (used only when es_urls is absent).
        host = cfg.get("es_host", "localhost")
        port = cfg.get("es_port", 9200)
        ssl = cfg.get("es_ssl", False)
        verify_certs = cfg.get("es_verify_certs", False)

        # Basic-auth tuple; None means unauthenticated.
        if "es_user" in cfg:
            auth = (cfg["es_user"], cfg["es_password"])
        else:
            auth = None

        headers = None
        # Optionally send credentials as a pre-encoded Basic Authorization
        # header instead of http_auth (some gateways require this form).
        if cfg.get("es_encoded_base64") and "es_user" in cfg and "es_password" in cfg:
            encoded_credentials = base64.b64encode(
                (cfg["es_user"] + ":" + cfg["es_password"]).encode("utf-8")
            ).decode("utf-8")
            headers = {
                'Authorization': 'Basic ' + encoded_credentials
            }
            # Credentials travel in the header, so drop the auth tuple.
            auth = None

        # Custom headers take precedence over the generated auth header.
        if "es_headers" in cfg:
            headers = cfg["es_headers"]

        timeout = cfg.get("es_timeout", 60)
        max_retries = cfg.get("es_max_retries", 10)

        # Options shared by both connection styles.
        client_kwargs = dict(
            http_compress=True,
            use_ssl=ssl,
            verify_certs=verify_certs,
            timeout=timeout,
            max_retries=max_retries,
            retry_on_timeout=True,
            http_auth=auth,
            headers=headers,
        )

        if es_urls is not None:
            # Initialize the OpenSearch client by passing a list of URLs.
            self.client = OpenSearch(hosts=[es_urls], **client_kwargs)
        else:
            # Initialize the OpenSearch client by passing a host and port.
            self.client = OpenSearch(
                hosts=[{'host': host, 'port': port}], **client_kwargs
            )

        # Run a simple request to verify connectivity; a failure is
        # reported but not fatal (best-effort, matching prior behavior).
        try:
            info = self.client.info()
            print('OpenSearch cluster info:', info)
        except Exception as e:
            print('Error connecting to OpenSearch cluster:', e)

        # Create the indices if they don't exist
        # self.create_index()

    def create_index(self):
        """Create the three backing indices, ignoring already-exists errors."""
        for index in [self.document_index, self.ddl_index, self.question_sql_index]:
            try:
                self.client.indices.create(index)
            except Exception as e:
                # Most commonly resource_already_exists_exception; treat
                # creation as idempotent and keep going.
                print("Error creating index: ", e)
                print(f"opensearch index {index} already exists")

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Index a DDL statement and return its generated document id."""
        # Ids are suffixed so remove_training_data can route deletes.
        id = str(uuid.uuid4()) + "-ddl"
        ddl_dict = {
            "ddl": ddl
        }
        response = self.client.index(index=self.ddl_index, body=ddl_dict,
                                     id=id, **kwargs)
        return response['_id']

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Index a documentation snippet and return its generated id."""
        id = str(uuid.uuid4()) + "-doc"
        doc_dict = {
            "doc": doc
        }
        response = self.client.index(index=self.document_index, id=id,
                                     body=doc_dict, **kwargs)
        return response['_id']

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Index a question/SQL pair and return its generated id."""
        id = str(uuid.uuid4()) + "-sql"
        question_sql_dict = {
            "question": question,
            "sql": sql
        }
        response = self.client.index(index=self.question_sql_index,
                                     body=question_sql_dict, id=id,
                                     **kwargs)
        return response['_id']

    def get_related_ddl(self, question: str, **kwargs) -> List[str]:
        """Return DDL statements lexically matching *question*."""
        query = {
            "query": {
                "match": {
                    "ddl": question
                }
            }
        }
        response = self.client.search(index=self.ddl_index, body=query,
                                      **kwargs)
        return [hit['_source']['ddl'] for hit in response['hits']['hits']]

    def get_related_documentation(self, question: str, **kwargs) -> List[str]:
        """Return documentation snippets lexically matching *question*."""
        query = {
            "query": {
                "match": {
                    "doc": question
                }
            }
        }
        response = self.client.search(index=self.document_index,
                                      body=query,
                                      **kwargs)
        return [hit['_source']['doc'] for hit in response['hits']['hits']]

    def get_similar_question_sql(self, question: str, **kwargs) -> List[tuple]:
        """Return ``(question, sql)`` pairs whose question matches *question*.

        Note: despite the historical ``List[str]`` annotation upstream,
        each element is a 2-tuple, which is what this now declares.
        """
        query = {
            "query": {
                "match": {
                    "question": question
                }
            }
        }
        response = self.client.search(index=self.question_sql_index,
                                      body=query,
                                      **kwargs)
        return [(hit['_source']['question'], hit['_source']['sql'])
                for hit in response['hits']['hits']]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all training data (docs, SQL pairs, DDL) as a DataFrame.

        WARNING: pulls up to 1000 documents per index with match_all —
        do not use this approach in production for large indices!
        """
        data = []

        response = self.client.search(
            index=self.document_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "documentation",
                    "question": "",
                    "content": hit["_source"]['doc'],
                }
            )

        response = self.client.search(
            index=self.question_sql_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "sql",
                    "question": hit.get("_source", {}).get("question", ""),
                    "content": hit.get("_source", {}).get("sql", ""),
                }
            )

        response = self.client.search(
            index=self.ddl_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "ddl",
                    "question": "",
                    "content": hit["_source"]['ddl'],
                }
            )

        return pd.DataFrame(data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete a training record by id, routing on the id suffix.

        Returns True on successful delete, False for an unrecognized
        suffix or any delete error (errors are printed, not raised).
        """
        try:
            if id.endswith("-sql"):
                # Pass **kwargs here too, consistent with the other branches.
                self.client.delete(index=self.question_sql_index, id=id, **kwargs)
                return True
            elif id.endswith("-ddl"):
                self.client.delete(index=self.ddl_index, id=id, **kwargs)
                return True
            elif id.endswith("-doc"):
                self.client.delete(index=self.document_index, id=id, **kwargs)
                return True
            else:
                return False
        except Exception as e:
            # Fixed: message text was accidentally duplicated.
            print("Error deleting training data: ", e)
            return False

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        # OpenSearch lexical search doesn't need embeddings; returns None.
        pass


# OpenSearch_VectorStore.__init__(self, config={'es_urls':
# "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
# "admin", 'es_password': "admin", 'es_verify_certs': True})


# OpenSearch_VectorStore.__init__(self, config={'es_host':
# "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin",
# 'es_password': "admin", 'es_verify_certs': True})
2 changes: 2 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def test_regular_imports():
from vanna.ollama.ollama import Ollama
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.openai.openai_embeddings import OpenAI_Embeddings
from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore
from vanna.remote import VannaDefault
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
Expand All @@ -23,6 +24,7 @@ def test_shortcut_imports():
from vanna.mistral import Mistral
from vanna.ollama import Ollama
from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
from vanna.opensearch import OpenSearch_VectorStore
from vanna.vannadb import VannaDB_VectorStore
from vanna.vllm import Vllm
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings

0 comments on commit 4c691e5

Please sign in to comment.