Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
arslanhashmi committed Aug 3, 2023
1 parent 7430069 commit 72909bb
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 63 deletions.
138 changes: 78 additions & 60 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
import sqlparse
from dataclasses import dataclass

from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, \
FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, \
from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, DataResult, PlotlyResult, Status, \
QuestionCategory, UserEmail, UserOTP, ApiKey, OrganizationList, \
Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON
from typing import List, Union, Callable, Tuple
from .exceptions import ImproperlyConfigured, DependencyError, ConnectionError, OTPCodeError, SQLRemoveError, \
Expand Down Expand Up @@ -116,6 +116,7 @@
_endpoint = "https://ask.vanna.ai/rpc"
_unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc"


def __unauthenticated_rpc_call(method, params):
headers = {
'Content-Type': 'application/json',
Expand Down Expand Up @@ -160,9 +161,11 @@ def __rpc_call(method, params):
response = requests.post(_endpoint, headers=headers, data=json.dumps(data))
return response.json()


def __dataclass_to_dict(obj):
    """Convert a dataclass instance into a plain dictionary.

    Thin wrapper around ``dataclasses.asdict``, which also converts
    nested dataclass instances found in the fields.

    Args:
        obj: A dataclass instance.

    Returns:
        A ``dict`` keyed by the dataclass field names.
    """
    as_plain_dict = dataclasses.asdict(obj)
    return as_plain_dict


def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
"""
**Example:**
Expand Down Expand Up @@ -238,7 +241,9 @@ def set_api_key(key: str) -> None:
models = get_models()

if len(models) == 0:
raise ConnectionError("There was an error communicating with the Vanna.AI API. Please try again or contact [email protected]")
raise ConnectionError(
"There was an error communicating with the Vanna.AI API. Please try again or contact [email protected]")


def get_models() -> List[str]:
"""
Expand Down Expand Up @@ -356,6 +361,7 @@ def update_model_visibility(public: bool) -> bool:

return status.success


def _set_org(org: str) -> None:
global __org

Expand Down Expand Up @@ -505,6 +511,7 @@ def add_documentation(documentation: str) -> bool:

return status.success


@dataclass
class TrainingPlanItem:
item_type: str
Expand Down Expand Up @@ -544,7 +551,7 @@ def __init__(self, plan: List[TrainingPlanItem]):

def __str__(self):
    """Return the plan summary as a single newline-separated string.

    The entries come from ``self.get_summary()`` and are joined in the
    order that method yields them.
    """
    summary_lines = self.get_summary()
    return "\n".join(summary_lines)

def __repr__(self):
    """Return the same text as ``__str__`` so both renderings match."""
    rendered = self.__str__()
    return rendered

Expand Down Expand Up @@ -584,7 +591,6 @@ def remove_item(self, item: str):
self._plan.remove(plan_item)
break



def __get_databases() -> List[str]:
try:
Expand All @@ -594,16 +600,20 @@ def __get_databases() -> List[str]:
df_databases = run_sql("SHOW DATABASES")
except:
return []

return df_databases['DATABASE_NAME'].unique().tolist()


def __get_information_schema_tables(database: str) -> pd.DataFrame:
    """Fetch every row of ``INFORMATION_SCHEMA.TABLES`` for one database.

    Args:
        database: Database name; it is interpolated as-is (unquoted)
            into the query text.

    Returns:
        The result of passing the query to the module-level ``run_sql``.
    """
    query = f'SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES'
    return run_sql(query)


def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan:
def get_training_plan_experimental(filter_databases: Union[List[str], None] = None,
filter_schemas: Union[List[str], None] = None,
include_information_schema: bool = False,
use_historical_queries: bool = True) -> TrainingPlan:
"""
**EXPERIMENTAL** : This method is experimental and may change in future versions.
Expand All @@ -625,15 +635,18 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
if use_historical_queries:
try:
print("Trying query history")
df_history = run_sql(""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")
df_history = run_sql(
""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")

df_history_filtered = df_history.query('ROWS_PRODUCED > 1')
if filter_databases is not None:
mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
df_history_filtered = df_history_filtered[mask]

if filter_schemas is not None:
mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
df_history_filtered = df_history_filtered[mask]

for query in df_history_filtered.sample(10)['QUERY_TEXT'].unique().tolist():
Expand All @@ -648,7 +661,7 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
print(e)

databases = __get_databases()

for database in databases:
if filter_databases is not None and database not in filter_databases:
continue
Expand All @@ -674,15 +687,17 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
for table in tables:
df_columns_filtered_to_table = df_columns_filtered_to_schema.query(f"TABLE_NAME == '{table}'")
doc = f"The following columns are in the {table} table in the {database} database:\n\n"
doc += df_columns_filtered_to_table[["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "COMMENT"]].to_markdown()

doc += df_columns_filtered_to_table[
["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE",
"COMMENT"]].to_markdown()

plan._plan.append(TrainingPlanItem(
item_type=TrainingPlanItem.ITEM_TYPE_IS,
item_group=f"{database}.{schema}",
item_name=table,
item_value=doc
))

except Exception as e:
print(e)
pass
Expand Down Expand Up @@ -711,36 +726,36 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
# print("Trying INFORMATION_SCHEMA.TABLES")
# df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES")

# breakpoint()

# try:
# print("Trying SCHEMATA")
# df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")

# for schema in df_schemata.schema_name.unique():
# df = run_sql(f"SELECT * FROM {schema}.information_schema.tables")

# for table in df.table_name.unique():
# plan._plan.append(TrainingPlanItem(
# item_type=TrainingPlanItem.ITEM_TYPE_IS,
# item_group=schema,
# item_name=table,
# item_value=None
# ))

# try:
# ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")

# plan._plan.append(TrainingPlanItem(
# item_type=TrainingPlanItem.ITEM_TYPE_DDL,
# item_group=schema,
# item_name=None,
# item_value=ddl_df.iloc[0, 0]
# ))
# except:
# pass
# except:
# pass
# breakpoint()

# try:
# print("Trying SCHEMATA")
# df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")

# for schema in df_schemata.schema_name.unique():
# df = run_sql(f"SELECT * FROM {schema}.information_schema.tables")

# for table in df.table_name.unique():
# plan._plan.append(TrainingPlanItem(
# item_type=TrainingPlanItem.ITEM_TYPE_IS,
# item_group=schema,
# item_name=table,
# item_value=None
# ))

# try:
# ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")

# plan._plan.append(TrainingPlanItem(
# item_type=TrainingPlanItem.ITEM_TYPE_DDL,
# item_group=schema,
# item_name=None,
# item_value=ddl_df.iloc[0, 0]
# ))
# except:
# pass
# except:
# pass

return plan

Expand All @@ -753,7 +768,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
vn.train()
```
Train Vanna.AI on a question and its corresponding SQL query.
Train Vanna.AI on a question and its corresponding SQL query.
If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database.
If you call it with the sql argument, it's equivalent to [`add_sql()`][vanna.add_sql].
If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl].
Expand Down Expand Up @@ -820,7 +835,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
print("Not able to add sql.")
return False
return False

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
Expand Down Expand Up @@ -915,7 +930,7 @@ def remove_sql(question: str) -> bool:
d = __rpc_call(method="remove_sql", params=params)

if 'result' not in d:
raise Exception(f"Error removing SQL")
raise Exception("Error removing SQL")
return False

status = Status(**d['result'])
Expand Down Expand Up @@ -943,7 +958,7 @@ def remove_training_data(id: str) -> bool:
d = __rpc_call(method="remove_training_data", params=params)

if 'result' not in d:
raise APIError(f"Error removing training data")
raise APIError("Error removing training data")

status = Status(**d['result'])

Expand Down Expand Up @@ -1110,11 +1125,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai

if print_results:
try:
Code = __import__('IPython.display', fromlist=['Code']).Code
display(Code(sql))
except Exception as e:
Code = __import__('IPython.display', fromlist=['Code']).Code
display(Code(sql))
except Exception:
print(sql)

if run_sql is None:
print("If you want to run the SQL query, provide a vn.run_sql function.")

Expand All @@ -1130,11 +1145,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
try:
display = __import__('IPython.display', fromlist=['display']).display
display(df)
except Exception as e:
except Exception:
print(df)

if len(df) > 0 and auto_train:
add_sql(question=question, sql=sql, tag=types.QuestionCategory.SQL_RAN)
add_sql(question=question, sql=sql, tag=QuestionCategory.SQL_RAN)

try:
plotly_code = generate_plotly_code(question=question, sql=sql, df=df)
Expand All @@ -1145,7 +1160,7 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
Image = __import__('IPython.display', fromlist=['Image']).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
except Exception:
fig.show()

if generate_followups:
Expand All @@ -1159,10 +1174,9 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
display = __import__('IPython.display', fromlist=['display']).display
Markdown = __import__('IPython.display', fromlist=['Markdown']).Markdown
display(Markdown(md))
except Exception as e:
except Exception:
print(md)


if print_results:
return None
else:
Expand Down Expand Up @@ -1190,7 +1204,8 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
return sql, None, None, None


def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame, chart_instructions: Union[str, None] = None) -> str:
def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame,
chart_instructions: Union[str, None] = None) -> str:
"""
**Example:**
```python
Expand Down Expand Up @@ -1333,6 +1348,7 @@ def generate_explanation(sql: str) -> str:

return explanation.explanation


def generate_question(sql: str) -> str:
"""
Expand Down Expand Up @@ -1426,6 +1442,7 @@ def get_training_data() -> pd.DataFrame:

return df


def connect_to_sqlite(url: str):
"""
Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
Expand Down Expand Up @@ -1458,6 +1475,7 @@ def run_sql_sqlite(sql: str):
global run_sql
run_sql = run_sql_sqlite


def connect_to_snowflake(account: str, username: str, password: str, database: str, role: Union[str, None] = None):
"""
Connect to Snowflake using the Snowflake connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
Expand Down Expand Up @@ -1487,7 +1505,7 @@ def connect_to_snowflake(account: str, username: str, password: str, database: s
snowflake = __import__('snowflake.connector')
except ImportError:
raise DependencyError("You need to install required dependencies to execute this method, run command:"
" \npip install vanna[snowflake]")
" \npip install vanna[snowflake]")

if username == 'my-username':
username_env = os.getenv('SNOWFLAKE_USERNAME')
Expand Down Expand Up @@ -1575,7 +1593,7 @@ def connect_to_postgres(host: str = None, dbname: str = None, user: str = None,
import psycopg2.extras
except ImportError:
raise DependencyError("You need to install required dependencies to execute this method,"
" run command: \npip install vanna[postgres]")
" run command: \npip install vanna[postgres]")

if not host:
host = os.getenv('HOST')
Expand Down
Loading

0 comments on commit 72909bb

Please sign in to comment.