Merge pull request #361 from scalabreseGD/ollama-fix-function
[Scalabrese] Improved Ollama integration
zainhoda authored Apr 30, 2024
2 parents 6787539 + 3358e4d commit 0703021
Showing 2 changed files with 96 additions and 70 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
@@ -41,5 +41,6 @@
anthropic = ["anthropic"]
gemini = ["google-generativeai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
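With ollama and httpx now listed in both the ollama extra and the all extra, installing the extra pulls in everything the reworked module imports. A minimal sanity check, assuming a released build that includes this change (the re-export of Ollama from the vanna.ollama package is an assumption based on the project's usual layout):

# pip install "vanna[ollama]"
from vanna.ollama import Ollama  # class defined in src/vanna/ollama/ollama.py below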
163 changes: 94 additions & 69 deletions src/vanna/ollama/ollama.py
@@ -1,76 +1,101 @@
import json
import re

import requests
from httpx import Timeout

from ..base import VannaBase
from ..exceptions import DependencyError


class Ollama(VannaBase):
    def __init__(self, config=None):
        if config is None or "ollama_host" not in config:
            self.host = "http://localhost:11434"
        else:
            self.host = config["ollama_host"]

        if config is None or "model" not in config:
            raise ValueError("config must contain an Ollama model")
        else:
            self.model = config["model"]

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.
        Args:
        - text (str): The string to search within for an SQL statement.
        Returns:
        - str: The first SQL statement found, with three backticks removed, or the original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace("```", "")
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        url = f"{self.host}/api/chat"
        data = {
            "model": self.model,
            "stream": False,
            "messages": prompt,
        }

        response = requests.post(url, json=data)

        response_dict = response.json()

        self.log(response.text)

        return response_dict["message"]["content"]
    def __init__(self, config=None):
        try:
            ollama = __import__("ollama")
        except ImportError:
            raise DependencyError(
                "You need to install required dependencies to execute this method, run command:"
                " \npip install ollama"
            )

        if not config:
            raise ValueError("config must contain at least the Ollama model")
        if "model" not in config:
            raise ValueError("config must contain at least the Ollama model")
        self.host = config.get("ollama_host", "http://localhost:11434")
        self.model = config["model"]
        # Default to the ":latest" tag when no explicit tag is given
        if ":" not in self.model:
            self.model += ":latest"

        self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))
        self.keep_alive = config.get('keep_alive', None)
        self.ollama_options = config.get('options', {})
        self.num_ctx = self.ollama_options.get('num_ctx', 2048)
        self.__pull_model_if_ne(self.ollama_client, self.model)

    @staticmethod
    def __pull_model_if_ne(ollama_client, model):
        # Pull the model from the Ollama server if it does not exist locally ("ne" = not exists)
        model_response = ollama_client.list()
        model_lists = [model_element['model'] for model_element in
                       model_response.get('models', [])]
        if model not in model_lists:
            ollama_client.pull(model)

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql(self, llm_response):
        """
        Extracts the first SQL statement from the LLM response: it prefers a fenced
        ```sql block, then falls back to a bare 'select' or 'with ... as (' statement
        (ignoring case), matching until the first semicolon, '[', or three backticks,
        after stripping the escape characters Ollama adds to the output.
        Args:
        - llm_response (str): The string to search within for an SQL statement.
        Returns:
        - str: The first SQL statement found, or the unmodified response if no match is found.
        """
        # Remove ollama-generated extra characters
        llm_response = llm_response.replace("\\_", "_")
        llm_response = llm_response.replace("\\", "")

        # Regular expression to find '```sql' and capture until ';', '[', or '```'
        sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
        # Regular expression to find 'select' or 'with ... as (' (ignoring case) and
        # capture until ';', '[' (which Mistral emits), or '```'
        select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
                                llm_response,
                                re.IGNORECASE | re.DOTALL)
        if sql:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
            return sql.group(1).replace("```", "")
        elif select_with:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
            return select_with.group(0)
        else:
            return llm_response
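    # Illustrative example (not part of this commit): for a response such as
    #   "Here is your query:\n```sql\nSELECT name FROM users;\n```"
    # the fenced-block branch matches first and extract_sql returns "SELECT name FROM users".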

    def submit_prompt(self, prompt, **kwargs) -> str:
        self.log(
            f"Ollama parameters:\n"
            f"model={self.model},\n"
            f"options={self.ollama_options},\n"
            f"keep_alive={self.keep_alive}")
        self.log(f"Prompt Content:\n{json.dumps(prompt)}")
        response_dict = self.ollama_client.chat(model=self.model,
                                                messages=prompt,
                                                stream=False,
                                                options=self.ollama_options,
                                                keep_alive=self.keep_alive)

        self.log(f"Ollama Response:\n{str(response_dict)}")

        return response_dict['message']['content']
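For reference, a minimal usage sketch of the reworked class, following the mixin pattern vanna uses elsewhere; the ChromaDB_VectorStore import and all config values here are illustrative assumptions, not part of this commit:

from vanna.chromadb import ChromaDB_VectorStore  # assumed vector-store mixin
from vanna.ollama import Ollama

class MyVanna(ChromaDB_VectorStore, Ollama):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Ollama.__init__(self, config=config)

# "model" is required; the other keys fall back to the defaults in __init__ above.
vn = MyVanna(config={
    "model": "llama3",                        # ":latest" is appended when no tag is given
    "ollama_host": "http://localhost:11434",  # default shown in the diff
    "options": {"num_ctx": 4096},             # forwarded to ollama_client.chat()
    "keep_alive": "5m",                       # optional; passed through unchanged
})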
