Skip to content

Commit

Permalink
modifying subselect join predictor left only for TS predictor
Browse files Browse the repository at this point in the history
subselect join regular model is proceeded by 'join tables' logic
#262
  • Loading branch information
ea-rus committed Jul 27, 2023
1 parent b4c9f64 commit f7660ba
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 99 deletions.
119 changes: 65 additions & 54 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,76 +926,87 @@ def get_aliased_fields(self, targets):
aliased_fields[target.alias.to_string()] = target
return aliased_fields

def plan_join(self, query, integration=None):
def adapt_dbt_query(self, query, integration):
orig_query = query

join = query.from_table
join_left = join.left
join_right = join.right

if isinstance(join_left, Select) and isinstance(join_left.from_table, Identifier):
# dbt query.
# dbt query.

# move latest into subquery
moved_conditions = []
# move latest into subquery
moved_conditions = []

def move_latest(node, **kwargs):
if isinstance(node, BinaryOperation):
if Latest() in node.args:
for arg in node.args:
if isinstance(arg, Identifier):
# remove table alias
arg.parts = [arg.parts[-1]]
moved_conditions.append(node)
def move_latest(node, **kwargs):
if isinstance(node, BinaryOperation):
if Latest() in node.args:
for arg in node.args:
if isinstance(arg, Identifier):
# remove table alias
arg.parts = [arg.parts[-1]]
moved_conditions.append(node)

query_traversal(query.where, move_latest)
query_traversal(query.where, move_latest)

# TODO make project step from query.target
# TODO make project step from query.target

# TODO support complex query. Only one table is supported at the moment.
# if not isinstance(join_left.from_table, Identifier):
# raise PlanningException(f'Statement not supported: {query.to_string()}')
# TODO support complex query. Only one table is supported at the moment.
# if not isinstance(join_left.from_table, Identifier):
# raise PlanningException(f'Statement not supported: {query.to_string()}')

# move properties to upper query
query = join_left
# move properties to upper query
query = join_left

if query.from_table.alias is not None:
table_alias = [query.from_table.alias.parts[0]]
if query.from_table.alias is not None:
table_alias = [query.from_table.alias.parts[0]]
else:
table_alias = query.from_table.parts

# add latest to query.where
for cond in moved_conditions:
if query.where is not None:
query.where = BinaryOperation('and', args=[query.where, cond])
else:
table_alias = query.from_table.parts
query.where = cond

# add latest to query.where
for cond in moved_conditions:
if query.where is not None:
query.where = BinaryOperation('and', args=[query.where, cond])
else:
query.where = cond

def add_aliases(node, is_table, **kwargs):
if not is_table and isinstance(node, Identifier):
if len(node.parts) == 1:
# add table alias to field
node.parts = table_alias + node.parts

query_traversal(query.where, add_aliases)

if isinstance(query.from_table, Identifier):
# DBT workaround: allow use tables without integration.
# if table.part[0] not in integration - take integration name from create table command
if (
integration is not None
and query.from_table.parts[0] not in self.databases
):
# add integration name to table
query.from_table.parts.insert(0, integration)
def add_aliases(node, is_table, **kwargs):
if not is_table and isinstance(node, Identifier):
if len(node.parts) == 1:
# add table alias to field
node.parts = table_alias + node.parts

query_traversal(query.where, add_aliases)

if isinstance(query.from_table, Identifier):
# DBT workaround: allow use tables without integration.
# if table.part[0] not in integration - take integration name from create table command
if (
integration is not None
and query.from_table.parts[0] not in self.databases
):
# add integration name to table
query.from_table.parts.insert(0, integration)

join_left = join_left.from_table

if orig_query.limit is not None:
if query.limit is None or query.limit.value > orig_query.limit.value:
query.limit = orig_query.limit
query.parentheses = False
query.alias = None

return query, join_left

def plan_join(self, query, integration=None):
orig_query = query

join_left = join_left.from_table
join = query.from_table
join_left = join.left
join_right = join.right

if orig_query.limit is not None:
if query.limit is None or query.limit.value > orig_query.limit.value:
query.limit = orig_query.limit
query.parentheses = False
query.alias = None
if isinstance(join_left, Select) and isinstance(join_left.from_table, Identifier):
if self.is_predictor(join_right) and self.get_predictor(join_right).get('timeseries'):
query, join_left = self.adapt_dbt_query(query, integration)

aliased_fields = self.get_aliased_fields(query.targets)

Expand Down
71 changes: 26 additions & 45 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def test_subselect(self):
sql = f'''
SELECT *
FROM (
select * from int.covid
select col from int.covid
limit 10
) as t
join mindsdb.pred
Expand All @@ -447,13 +447,15 @@ def test_subselect(self):
default_namespace='mindsdb',
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from covid limit 5')),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('result_0'),
right=Identifier('result_1'),
query=parse_sql('select covid.col as col from covid limit 10')),
SubSelectStep(query=Select(targets=[Star()]), dataframe=Result(0), table_name='t'),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(1), predictor=Identifier('pred')),
JoinStep(left=Result(1), right=Result(2),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
ProjectStep(dataframe=Result(2), columns=[Star()])
LimitOffsetStep(dataframe=Result(3), limit=5),
ProjectStep(dataframe=Result(4), columns=[Star()])
],
)

Expand All @@ -467,55 +469,34 @@ def test_subselect(self):
for i in range(len(plan.steps)):
assert plan.steps[i] == expected_plan.steps[i]

# nested limit is lesser
sql = f'''
SELECT *
FROM (
select * from int.covid
limit 5
) as t
join mindsdb.pred
limit 50
'''

plan = plan_query(
query,
integrations=['int'],
predictor_namespace='mindsdb',
default_namespace='mindsdb',
predictor_metadata={'pred': {}}
)
for i in range(len(plan.steps)):
assert plan.steps[i] == expected_plan.steps[i]

# nested select without limit
# only nested select with limit
sql = f'''
SELECT *
FROM (
select * from int.covid
join int.info
limit 5
) as t
join mindsdb.pred
limit 5
'''

plan = plan_query(
query,
integrations=['int'],
predictor_namespace='mindsdb',
query = parse_sql(sql, dialect='mindsdb')

expected_plan = QueryPlan(
default_namespace='mindsdb',
predictor_metadata={'pred': {}}
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from covid join info limit 5')),
SubSelectStep(query=Select(targets=[Star()]), dataframe=Result(0), table_name='t'),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(1), predictor=Identifier('pred')),
JoinStep(left=Result(1), right=Result(2),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
ProjectStep(dataframe=Result(3), columns=[Star()])
],
)
for i in range(len(plan.steps)):
assert plan.steps[i] == expected_plan.steps[i]

# only nested select with limit
sql = f'''
SELECT *
FROM (
select * from int.covid limit 5
) as t
join mindsdb.pred
'''

plan = plan_query(
query,
Expand Down

0 comments on commit f7660ba

Please sign in to comment.