From 812b0327b34d2f4d219446a3040f2065d45aacc4 Mon Sep 17 00:00:00 2001 From: scalabrese Date: Sun, 14 Apr 2024 11:00:53 +0200 Subject: [PATCH 1/9] [Scalabrese] Improved Ollama integration --- pyproject.toml | 7 +- src/vanna/ollama/ollama.py | 155 ++++++++++++++++++++----------------- 2 files changed, 90 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dcc09dc8..4fb291e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi" name = "vanna" version = "0.3.4" authors = [ - { name="Zain Hoda", email="zain@vanna.ai" }, + { name = "Zain Hoda", email = "zain@vanna.ai" }, ] description = "Generate SQL queries from natural language" @@ -18,7 +18,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy" + "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy", "httpx" ] [project.urls] @@ -31,7 +31,7 @@ mysql = ["PyMySQL"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "ollama"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] @@ -40,3 +40,4 @@ anthropic = ["anthropic"] gemini = ["google-generativeai"] marqo = ["marqo"] zhipuai = ["zhipuai"] +ollama = ["ollama"] diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 47bf602d..2ca1fa60 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -1,76 +1,93 @@ +import json import re -import requests +from httpx import Timeout from ..base import VannaBase +from ..exceptions import DependencyError class Ollama(VannaBase): - def __init__(self, config=None): - if config is None or "ollama_host" not in config: - self.host = "http://localhost:11434" - else: - self.host = config["ollama_host"] - - if config is None or "model" not in config: - raise ValueError("config must contain a Ollama model") - else: - self.model = config["model"] - - def system_message(self, message: str) -> any: - return {"role": "system", "content": message} - - def user_message(self, message: str) -> any: - return {"role": "user", "content": message} - - def assistant_message(self, message: str) -> any: - return {"role": "assistant", "content": message} - - def extract_sql_query(self, text): - """ - Extracts the first SQL statement after the word 'select', ignoring case, - matches until the first semicolon, three backticks, or the end of the string, - and removes three backticks if they exist in the extracted string. - - Args: - - text (str): The string to search within for an SQL statement. - - Returns: - - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. - """ - # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string - pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL) - - match = pattern.search(text) - if match: - # Remove three backticks from the matched string if they exist - return match.group(0).replace("```", "") - else: - return text - - def generate_sql(self, question: str, **kwargs) -> str: - # Use the super generate_sql - sql = super().generate_sql(question, **kwargs) - - # Replace "\_" with "_" - sql = sql.replace("\\_", "_") - - sql = sql.replace("\\", "") - - return self.extract_sql_query(sql) - - def submit_prompt(self, prompt, **kwargs) -> str: - url = f"{self.host}/api/chat" - data = { - "model": self.model, - "stream": False, - "messages": prompt, - } - - response = requests.post(url, json=data) - - response_dict = response.json() - - self.log(response.text) - - return response_dict["message"]["content"] + def __init__(self, config=None): + + try: + ollama = __import__("ollama") + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method, run command:" + " \npip install ollama" + ) + + if not config: + raise ValueError("config must contain at least Ollama model") + if 'model' not in config.keys(): + raise ValueError("config must contain at least Ollama model") + self.host = config.get("ollama_host", "http://localhost:11434") + self.model = config["model"] + + self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0)) + self.keep_alive = config.get('keep_alive', None) + self.ollama_options = config.get('options', {}) + self.num_ctx = self.ollama_options.get('num_ctx', 2048) + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def extract_sql(self, llm_response): + """ + Extracts the first SQL statement after the word 'select', ignoring case, + matches until the first semicolon, three backticks, or the end of the string, + and removes three backticks if they exist in the extracted string. + + Args: + - llm_response (str): The string to search within for an SQL statement. + + Returns: + - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. + """ + # Remove ollama-generated extra characters + llm_response = llm_response.replace("\\_", "_") + llm_response = llm_response.replace("\\", "") + + # Regular expression to find 'select, with and ```sql' (ignoring case) and capture until ';', '```', [ (this happens in case of mistral) or end of string + pattern = re.compile(r'(?:select|with|```sql).*?(?:;|```|(\[)|$)', + re.IGNORECASE | re.DOTALL) + + match = pattern.search(llm_response) + if match: + # Remove three backticks from the matched string if they exist + return match.group(0).replace("```", "") + else: + return llm_response + + def __pull_model_if_ne(self, ): + model_response = self.ollama_client.list() + model_lists = [model_element['model'] for model_element in + model_response.get('models', [])] + if self.model not in model_lists: + self.log(f"Pulling model {self.model}....") + self.ollama_client.pull(self.model) + + def submit_prompt(self, prompt, **kwargs) -> str: + self.log( + f"Ollama parameters:\n" + f"model={self.model},\n" + f"options={self.ollama_options},\n" + f"keep_alive={self.keep_alive}") + self.log(f"Prompt Content:\n{json.dumps(prompt)}") + self.__pull_model_if_ne() + response_dict = self.ollama_client.chat(model=self.model, + messages=prompt, + stream=False, + options=self.ollama_options, + keep_alive=self.keep_alive) + + self.log(str(response_dict)) + + return response_dict['message']['content'] From 449cafa55974f7593dd8cf74100713d3e23b0f34 Mon Sep 17 00:00:00 2001 From: scalabrese Date: Sun, 14 Apr 2024 11:05:33 +0200 Subject: [PATCH 2/9] [Scalabrese] Improved Ollama integration --- src/vanna/ollama/ollama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 2ca1fa60..02719611 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -88,6 +88,6 @@ def submit_prompt(self, prompt, **kwargs) -> str: options=self.ollama_options, keep_alive=self.keep_alive) - self.log(str(response_dict)) + self.log(f"Ollama Response:\n{str(response_dict)}") return response_dict['message']['content'] From 7ebc77b541090bdb3c716732b2e088b307705175 Mon Sep 17 00:00:00 2001 From: scalabreseGD <47219719+scalabreseGD@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:53:47 +0200 Subject: [PATCH 3/9] Update ollama.py --- src/vanna/ollama/ollama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 02719611..73cdbde8 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -56,7 +56,7 @@ def extract_sql(self, llm_response): llm_response = llm_response.replace("\\", "") # Regular expression to find 'select, with and ```sql' (ignoring case) and capture until ';', '```', [ (this happens in case of mistral) or end of string - pattern = re.compile(r'(?:select|with|```sql).*?(?:;|```|(\[)|$)', + pattern = re.compile(r'(?:select|with|```sql).*?(?=;|\[|```|$)', re.IGNORECASE | re.DOTALL) match = pattern.search(llm_response) From 77c5e75b5db8db7de7e11244eb77285646277f58 Mon Sep 17 00:00:00 2001 From: scalabrese Date: Tue, 16 Apr 2024 10:34:02 +0200 Subject: [PATCH 4/9] [Scalabrese] Improved Ollama integration --- src/vanna/ollama/ollama.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 73cdbde8..9a94c16d 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -55,14 +55,19 @@ def extract_sql(self, llm_response): llm_response = llm_response.replace("\\_", "_") llm_response = llm_response.replace("\\", "") - # Regular expression to find 'select, with and ```sql' (ignoring case) and capture until ';', '```', [ (this happens in case of mistral) or end of string - pattern = re.compile(r'(?:select|with|```sql).*?(?=;|\[|```|$)', - re.IGNORECASE | re.DOTALL) - - match = pattern.search(llm_response) - if match: - # Remove three backticks from the matched string if they exist - return match.group(0).replace("```", "") + # Regular expression to find ```sql' and capture until '```' + sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL) + # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string + select_with = re.search(r'(?:select|with).*?(?=;|\[|$)', llm_response, + re.IGNORECASE | re.DOTALL) + if sql: + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") + return sql.group(1).replace("```", "") + elif select_with: + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(1)}") + return select_with.group(1) else: return llm_response From 26d8f50943f679846dbad29be3c86e70fc2dbe8e Mon Sep 17 00:00:00 2001 From: scalabrese Date: Tue, 16 Apr 2024 10:36:13 +0200 Subject: [PATCH 5/9] [Scalabrese] Made httpx dependency for ollama only --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4fb291e8..6391bb0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy", "httpx" + "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy" ] [project.urls] @@ -40,4 +40,4 @@ anthropic = ["anthropic"] gemini = ["google-generativeai"] marqo = ["marqo"] zhipuai = ["zhipuai"] -ollama = ["ollama"] +ollama = ["ollama", "httpx"] From cf3e9a7318230cc805dfdae3aeeedced659e5cc0 Mon Sep 17 00:00:00 2001 From: scalabrese Date: Tue, 16 Apr 2024 10:36:48 +0200 Subject: [PATCH 6/9] [Scalabrese] Made httpx dependency for ollama only --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6391bb0b..24aa6400 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ mysql = ["PyMySQL"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "ollama"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "ollama", "httpx"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] From 8ce00a44b6bdda4a77c9f44893bb03c772e82d58 Mon Sep 17 00:00:00 2001 From: scalabreseGD <47219719+scalabreseGD@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:42:17 +0200 Subject: [PATCH 7/9] Update ollama.py --- src/vanna/ollama/ollama.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 9a94c16d..e9926efb 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -61,15 +61,15 @@ def extract_sql(self, llm_response): select_with = re.search(r'(?:select|with).*?(?=;|\[|$)', llm_response, re.IGNORECASE | re.DOTALL) if sql: - self.log( - f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") - return sql.group(1).replace("```", "") + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") + return sql.group(1).replace("```", "") elif select_with: - self.log( - f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(1)}") - return select_with.group(1) + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}") + return select_with.group(0) else: - return llm_response + return llm_response def __pull_model_if_ne(self, ): model_response = self.ollama_client.list() From 3840c278d60b4512f9dc0ce1c301848cece94bc7 Mon Sep 17 00:00:00 2001 From: scalabreseGD <47219719+scalabreseGD@users.noreply.github.com> Date: Thu, 18 Apr 2024 19:21:48 +0200 Subject: [PATCH 8/9] Update ollama.py --- src/vanna/ollama/ollama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index e9926efb..226dc6ba 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -56,9 +56,9 @@ def extract_sql(self, llm_response): llm_response = llm_response.replace("\\", "") # Regular expression to find ```sql' and capture until '```' - sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL) + sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL) # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string - select_with = re.search(r'(?:select|with).*?(?=;|\[|$)', llm_response, + select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)', llm_response, re.IGNORECASE | re.DOTALL) if sql: self.log( From 20ab1f7ea869b2cd98e2a9e4d4948096d7181500 Mon Sep 17 00:00:00 2001 From: scalabrese Date: Sat, 27 Apr 2024 17:05:34 +0200 Subject: [PATCH 9/9] [Scalabrese] Made pull model only at the startup --- src/vanna/ollama/ollama.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index 226dc6ba..ecf27603 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -24,11 +24,22 @@ def __init__(self, config=None): raise ValueError("config must contain at least Ollama model") self.host = config.get("ollama_host", "http://localhost:11434") self.model = config["model"] + if ":" in self.model: + self.model += ":latest" self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0)) self.keep_alive = config.get('keep_alive', None) self.ollama_options = config.get('options', {}) self.num_ctx = self.ollama_options.get('num_ctx', 2048) + self.__pull_model_if_ne(self.ollama_client, self.model) + + @staticmethod + def __pull_model_if_ne(ollama_client, model): + model_response = ollama_client.list() + model_lists = [model_element['model'] for model_element in + model_response.get('models', [])] + if model not in model_lists: + ollama_client.pull(model) def system_message(self, message: str) -> any: return {"role": "system", "content": message} @@ -58,26 +69,19 @@ def extract_sql(self, llm_response): # Regular expression to find ```sql' and capture until '```' sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL) # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string - select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)', llm_response, + select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)', + llm_response, re.IGNORECASE | re.DOTALL) if sql: - self.log( - f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") - return sql.group(1).replace("```", "") + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") + return sql.group(1).replace("```", "") elif select_with: - self.log( - f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}") - return select_with.group(0) + self.log( + f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}") + return select_with.group(0) else: - return llm_response - - def __pull_model_if_ne(self, ): - model_response = self.ollama_client.list() - model_lists = [model_element['model'] for model_element in - model_response.get('models', [])] - if self.model not in model_lists: - self.log(f"Pulling model {self.model}....") - self.ollama_client.pull(self.model) + return llm_response def submit_prompt(self, prompt, **kwargs) -> str: self.log( @@ -86,7 +90,6 @@ def submit_prompt(self, prompt, **kwargs) -> str: f"options={self.ollama_options},\n" f"keep_alive={self.keep_alive}") self.log(f"Prompt Content:\n{json.dumps(prompt)}") - self.__pull_model_if_ne() response_dict = self.ollama_client.chat(model=self.model, messages=prompt, stream=False,