Skip to content

Commit

Permalink
Replace `False` return values with raised exceptions
Browse files: browse the repository at this point in the history
  • Loading branch information
arslanhashmi committed Aug 4, 2023
1 parent b4c2410 commit 420ccbf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 55 deletions.
91 changes: 54 additions & 37 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def set_model(model: str):
_set_org(org=model)


def add_sql(question: str, sql: str, tag: Union[str, None] = "Manually Trained") -> bool:
def add_sql(question: str, sql: str, tag: Union[str, None] = "Manually Trained") -> Union[bool, str]:
"""
Adds a question and its corresponding SQL query to the model's training data
Expand Down Expand Up @@ -440,15 +440,12 @@ def add_sql(question: str, sql: str, tag: Union[str, None] = "Manually Trained")

d = __rpc_call(method="store_sql", params=params)

if 'result' not in d:
return False

status = Status(**d['result'])

return status.success
return status.success, status.message


def add_ddl(ddl: str) -> bool:
def add_ddl(ddl: str) -> Union[bool, str]:
"""
Adds a DDL statement to the model's training data
Expand All @@ -469,15 +466,12 @@ def add_ddl(ddl: str) -> bool:

d = __rpc_call(method="store_ddl", params=params)

if 'result' not in d:
return False

status = Status(**d['result'])

return status.success
return status.success, status.message


def add_documentation(documentation: str) -> bool:
def add_documentation(documentation: str) -> Union[bool, str]:
"""
Adds documentation to the model's training data
Expand All @@ -498,12 +492,9 @@ def add_documentation(documentation: str) -> bool:

d = __rpc_call(method="store_documentation", params=params)

if 'result' not in d:
return False

status = Status(**d['result'])

return status.success
return status.success, status.message

@dataclass
class TrainingPlanItem:
Expand Down Expand Up @@ -840,27 +831,42 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:

if documentation:
print("Adding documentation....")
return add_documentation(documentation)
success, error_message = add_documentation(sql)
if not success:
raise APIError(error_message)

return success

if sql:
if question is None:
question = generate_question(sql)
print("Question generated with sql:", Question, '\nAdding SQL...')
return add_sql(question=question, sql=sql)

success, error_message = add_sql(question=question, sql=sql)
if not success:
raise APIError(error_message)

return success

if ddl:
print("Adding ddl:", ddl)
return add_ddl(sql)
success, error_message = add_ddl(sql)
if not success:
raise APIError(error_message)

return success

if json_file:
validate_config_path(json_file)
with open(json_file, 'r') as js_file:
data = json.load(js_file)
print("Adding Questions And SQLs using file:", json_file)
for question in data:
if not add_sql(question=question['question'], sql=question['answer']):
success, error_message = add_sql(question=question['question'], sql=question['answer'])
if not success:
print(f"Not able to add sql for question: {question['question']} from {json_file}")
return False
raise APIError(error_message)

return True

if sql_file:
Expand All @@ -869,34 +875,45 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
sql_statements = sqlparse.split(file.read())
for statement in sql_statements:
if 'CREATE TABLE' in statement:
if add_ddl(statement):
print("ddl Added!")
return True
print("Not able to add DDL")
return False
success, error_message = add_ddl(statement)
if not success:
print("Not able to add DDL")
raise APIError(error_message)

print("ddl Added!")
return success

else:
question = generate_question(sql=statement)
if add_sql(question=question, sql=statement):
print("SQL added!")
return True
print("Not able to add sql.")
return False
success, error_message = add_sql(question=question, sql=statement)
if not success:
print("Not able to add sql.")
raise APIError(error_message)

print("SQL added!")
return success

return False

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
if not add_ddl(item.item_value):
success, error_message = add_ddl(item.item_value)
if not success:
print(f"Not able to add ddl for {item.item_group}")
return False
raise APIError(error_message)

elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
if not add_documentation(item.item_value):
success, error_message = add_documentation(item.item_value)
if not success:
print(f"Not able to add documentation for {item.item_group}.{item.item_name}")
return False
raise APIError(error_message)

elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
if not add_sql(question=item.item_name, sql=item.item_value):
success, error_message = add_sql(question=item.item_name, sql=item.item_value)
if not success:
print(f"Not able to add sql for {item.item_group}.{item.item_name}")
return False
raise APIError(error_message)


def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None) -> bool:
Expand Down
43 changes: 25 additions & 18 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def test_generate_followup_questions():
assert questions == ['AI Response']

def test_add_sql():
rv = vn.add_sql(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'")
assert rv == True
success, _ = vn.add_sql(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'")
assert success == True

rv = vn.add_sql(question="What's the data about student Jane Doe?", sql="SELECT * FROM students WHERE name = 'Jane Doe'")
assert rv == True
success, _ = vn.add_sql(question="What's the data about student Jane Doe?", sql="SELECT * FROM students WHERE name = 'Jane Doe'")
assert success == True

def test_generate_sql_caching():
rv = vn.generate_sql(question="What's the data about student John Doe?")
Expand All @@ -167,7 +167,9 @@ def test_remove_sql():
rv = vn.remove_sql(question="What's the data about student John Doe?")
assert rv == True

def test_flag_sql():
def test_flag_sql(monkeypatch):
switch_to_user('user1', monkeypatch)
vn.set_model('test-org')
rv = vn.flag_sql_for_review(question="What's the data about student Jane Doe?")
assert rv == True

Expand All @@ -184,30 +186,35 @@ def test_get_all_questions():
# assert rv == AccuracyStats(num_questions=2, data={'No SQL Generated': 2, 'SQL Unable to Run': 0, 'Assumed Correct': 0, 'Flagged for Review': 0, 'Reviewed and Approved': 0, 'Reviewed and Rejected': 0, 'Reviewed and Updated': 0})

def test_add_documentation_fail():
rv = vn.add_documentation(documentation="This is the documentation")
assert rv == False
success, error_message = vn.add_documentation(documentation="This is the documentation")
assert success == False
assert error_message == "Failed to store documentation: User [email protected] is not an admin of organization demo-tpc-h"

def test_add_ddl_pass_fail():
rv = vn.add_ddl(ddl="This is the ddl")
assert rv == False
success, error_message = vn.add_ddl(ddl="This is the ddl")
assert success == False
assert error_message == "Failed to store DDL: User [email protected] is not an admin of organization demo-tpc-h"

def test_add_sql_pass_fail():
rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students")
assert rv == False
success, error_message = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students")
assert success == False
assert error_message == "Failed to store question: User [email protected] is not an admin of organization demo-tpc-h"

def test_add_documentation_pass(monkeypatch):
switch_to_user('user1', monkeypatch)
vn.set_model('test-org')
rv = vn.add_documentation(documentation="This is the documentation")
assert rv == True
success, _ = vn.add_documentation(documentation="This is the documentation")
assert success == True

def test_add_ddl_pass():
rv = vn.add_ddl(ddl="This is the ddl")
assert rv == True
success, _ = vn.add_ddl(ddl="This is the ddl")
assert success == True

def test_add_sql_pass():
rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students")
assert rv == True
def test_add_sql_pass(monkeypatch):
switch_to_user('user1', monkeypatch)
vn.set_model('test-org')
success, _ = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students")
assert success == True

num_training_data = 4

Expand Down

0 comments on commit 420ccbf

Please sign in to comment.