Skip to content

Commit

Permalink
Merge pull request #98 from vanna-ai/get_related_training_data
Browse files Browse the repository at this point in the history
get related training data
  • Loading branch information
zainhoda authored Aug 16, 2023
2 parents d9c631e + 309f244 commit e360a13
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
28 changes: 27 additions & 1 deletion src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, \
FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, \
Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON
Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON, TrainingData
from typing import List, Union, Callable, Tuple
from .exceptions import ImproperlyConfigured, DependencyError, ConnectionError, OTPCodeError, SQLRemoveError, \
ValidationError, APIError
Expand Down Expand Up @@ -1043,6 +1043,32 @@ def generate_sql(question: str) -> str:

return sql_answer.sql

def get_related_training_data(question: str) -> TrainingData:
"""
**Example:**
```python
training_data = vn.get_related_training_data(question="What is the average salary of employees?")
```
Get the training data related to a question.
Args:
question (str): The question to get related training data for.
Returns:
TrainingData or None: The related training data, or None if an error occurred.
"""
params = [Question(question=question)]

d = __rpc_call(method="get_related_training_data", params=params)

if 'result' not in d:
return None

# Load the result into a dataclass
training_data = TrainingData(**d['result'])

return training_data

def generate_meta(question: str) -> str:
"""
Expand Down
8 changes: 7 additions & 1 deletion src/vanna/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,10 @@ class StringData:

@dataclass
class DataFrameJSON:
data: str
data: str

@dataclass
class TrainingData:
questions: List[dict]
ddl: List[str]
documentation: List[str]
6 changes: 6 additions & 0 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ def test_double_train():
training_data = vn.get_training_data()
assert training_data.shape == (1, 4)

def test_get_related_training_data():
data = vn.get_related_training_data(question="What's the data about student John Doe?")
assert data.questions[0]['question'] == 'What is the total sales for each product?'
assert data.questions[0]['sql'] == 'SELECT * FROM ...'
assert data.ddl == ['DDL here']
assert data.documentation == ['Documentation here']

@pytest.mark.parametrize("params", [
dict(
Expand Down

0 comments on commit e360a13

Please sign in to comment.