-
Notifications
You must be signed in to change notification settings - Fork 966
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b4c2410
commit 420ccbf
Showing
2 changed files
with
79 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,11 +152,11 @@ def test_generate_followup_questions(): | |
assert questions == ['AI Response'] | ||
|
||
def test_add_sql(): | ||
rv = vn.add_sql(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'") | ||
assert rv == True | ||
success, _ = vn.add_sql(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'") | ||
assert success == True | ||
|
||
rv = vn.add_sql(question="What's the data about student Jane Doe?", sql="SELECT * FROM students WHERE name = 'Jane Doe'") | ||
assert rv == True | ||
success, _ = vn.add_sql(question="What's the data about student Jane Doe?", sql="SELECT * FROM students WHERE name = 'Jane Doe'") | ||
assert success == True | ||
|
||
def test_generate_sql_caching(): | ||
rv = vn.generate_sql(question="What's the data about student John Doe?") | ||
|
@@ -167,7 +167,9 @@ def test_remove_sql(): | |
rv = vn.remove_sql(question="What's the data about student John Doe?") | ||
assert rv == True | ||
|
||
def test_flag_sql(): | ||
def test_flag_sql(monkeypatch): | ||
switch_to_user('user1', monkeypatch) | ||
vn.set_model('test-org') | ||
rv = vn.flag_sql_for_review(question="What's the data about student Jane Doe?") | ||
assert rv == True | ||
|
||
|
@@ -184,30 +186,35 @@ def test_get_all_questions(): | |
# assert rv == AccuracyStats(num_questions=2, data={'No SQL Generated': 2, 'SQL Unable to Run': 0, 'Assumed Correct': 0, 'Flagged for Review': 0, 'Reviewed and Approved': 0, 'Reviewed and Rejected': 0, 'Reviewed and Updated': 0}) | ||
|
||
def test_add_documentation_fail(): | ||
rv = vn.add_documentation(documentation="This is the documentation") | ||
assert rv == False | ||
success, error_message = vn.add_documentation(documentation="This is the documentation") | ||
assert success == False | ||
assert error_message == "Failed to store documentation: User [email protected] is not an admin of organization demo-tpc-h" | ||
|
||
def test_add_ddl_pass_fail(): | ||
rv = vn.add_ddl(ddl="This is the ddl") | ||
assert rv == False | ||
success, error_message = vn.add_ddl(ddl="This is the ddl") | ||
assert success == False | ||
assert error_message == "Failed to store DDL: User [email protected] is not an admin of organization demo-tpc-h" | ||
|
||
def test_add_sql_pass_fail(): | ||
rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") | ||
assert rv == False | ||
success, error_message = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") | ||
assert success == False | ||
assert error_message == "Failed to store question: User [email protected] is not an admin of organization demo-tpc-h" | ||
|
||
def test_add_documentation_pass(monkeypatch): | ||
switch_to_user('user1', monkeypatch) | ||
vn.set_model('test-org') | ||
rv = vn.add_documentation(documentation="This is the documentation") | ||
assert rv == True | ||
success, _ = vn.add_documentation(documentation="This is the documentation") | ||
assert success == True | ||
|
||
def test_add_ddl_pass(): | ||
rv = vn.add_ddl(ddl="This is the ddl") | ||
assert rv == True | ||
success, _ = vn.add_ddl(ddl="This is the ddl") | ||
assert success == True | ||
|
||
def test_add_sql_pass(): | ||
rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") | ||
assert rv == True | ||
def test_add_sql_pass(monkeypatch): | ||
switch_to_user('user1', monkeypatch) | ||
vn.set_model('test-org') | ||
success, _ = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") | ||
assert success == True | ||
|
||
num_training_data = 4 | ||
|
||
|