Skip to content

Commit

Permalink
Merge pull request #374 from DagsHub/bug/sequential-querying
Browse files Browse the repository at this point in the history
Fix sequential querying composition bug
  • Loading branch information
kbolashev authored Oct 5, 2023
2 parents 93d1a1a + 5b454c3 commit 1cf23d7
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
23 changes: 12 additions & 11 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,19 +509,20 @@ def __getitem__(self, other: Union[slice, str, "Datasource"]):
if type(other) is str:
if not self.has_field(other):
raise FieldNotFoundError(other)
new_ds._query = DatasourceQuery(other)
return new_ds
else:
# "index" is a datasource with a query - compose with "and"
# Example:
# ds = Dataset()
# filtered_ds = ds[ds["aaa"] > 5]
# filtered_ds2 = filtered_ds[filtered_ds["bbb"] < 4]
other_query = DatasourceQuery(other)
if self._query.is_empty:
new_ds._query = other._query
return new_ds
new_ds._query = other_query
else:
return other.__and__(self)
new_ds._query.compose("and", other_query)
return new_ds
# "index" is a datasource with a query - return the datasource inside
# Example:
# ds = Dataset()
# filtered_ds = ds[ds["aaa"] > 5]
# filtered_ds2 = filtered_ds[filtered_ds["bbb"] < 4]
# filtered_ds2 will be "aaa" > 5 AND "bbb" < 4
if isinstance(other, Datasource):
return other

def __gt__(self, other: object):
self._test_not_comparing_other_ds(other)
Expand Down
34 changes: 20 additions & 14 deletions dagshub/data_engine/model/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ class FieldFilterOperand(enum.Enum):
for k, v in fieldFilterOperandMap.items():
fieldFilterOperandMapReverseMap[v.value] = k

UNFILLED_NODE_TAG = "undefined"


class DatasourceQuery:
def __init__(self, column_or_query: Optional[Union[str, "DatasourceQuery"]] = None):
self._operand_tree: Tree = Tree()
self._column_filter: Optional[str] = None # for storing filters when user does ds["column"]
if type(column_or_query) is str:
# If it's ds["column"] then the root node is just the column name
self._column_filter = column_or_query
# If it's ds["column"] then the root node is just the column name, will be filled later
self._operand_tree.create_node(UNFILLED_NODE_TAG, data={"field": column_or_query})
elif column_or_query is not None:
self._operand_tree.create_node(column_or_query)

Expand All @@ -68,13 +69,17 @@ def __repr__(self):

@property
def column_filter(self) -> Optional[str]:
return self._column_filter
filter_node = self._column_filter_node
if filter_node is None:
return None
return filter_node.data["field"]

def compose(self, op: str, other: Optional[Union[str, int, float, "DatasourceQuery", "Datasource"]]):
if self._column_filter is not None:
# Just the column is in the query - compose into a tree
self._operand_tree.create_node(op, data={"field": self._column_filter, "value": other})
self._column_filter = None
if self._column_filter_node is not None:
# If there was an unfilled query node with a column - put the operand in that node
node = self._column_filter_node
node.tag = op
node.data.update({"value": other})
elif op == "isnull":
# Can only do isnull on the column filter, if we got here, there's something wrong
raise RuntimeError("is_null operation can only be done on a column (e.g. ds['col1'].is_null())")
Expand All @@ -98,14 +103,18 @@ def compose(self, op: str, other: Optional[Union[str, int, float, "DatasourceQue
if self.is_empty:
self._operand_tree = other._operand_tree
return
elif other.is_empty:
elif other.is_empty and other._column_filter_node is None:
return
composite_tree = Tree()
root_node = composite_tree.create_node(op)
composite_tree.paste(root_node.identifier, self._operand_tree)
composite_tree.paste(root_node.identifier, other._operand_tree)
self._operand_tree = composite_tree

@property
def _column_filter_node(self) -> Optional[Node]:
    """Return the first operand-tree node still tagged as unfilled, if any.

    A node carries ``UNFILLED_NODE_TAG`` when the query was created from a
    bare column name (``ds["column"]``) and no operator has been composed
    into it yet; the column name is kept in the node's ``data["field"]``.

    Returns:
        The first matching node, or ``None`` when the tree has no unfilled
        node. Callers rely on the ``None`` result to detect that case.
    """
    unfilled_nodes = self._operand_tree.filter_nodes(lambda n: n.tag == UNFILLED_NODE_TAG)
    # The default argument turns an exhausted iterator into None instead of
    # letting StopIteration escape.
    return next(unfilled_nodes, None)

@property
def _operand_root(self) -> Node:
    """Return the node object stored at the root of the operand tree."""
    tree = self._operand_tree
    root_id = tree.root
    # Indexing the tree by a node identifier yields the node itself.
    return tree[root_id]
Expand Down Expand Up @@ -198,12 +207,9 @@ def to_dict(self):

def __deepcopy__(self, memodict={}):
q = DatasourceQuery()
if self._column_filter is not None:
q._column_filter = self._column_filter
else:
q._operand_tree = Tree(tree=self._operand_tree, deep=True)
q._operand_tree = Tree(tree=self._operand_tree, deep=True)
return q

@property
def is_empty(self):
return self._column_filter is not None or self._operand_tree.root is None
return self._operand_tree.root is None or self._column_filter_node is not None
20 changes: 19 additions & 1 deletion tests/data_engine/test_querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_query_single_column(ds):
assert type(ds2) is Datasource

q = ds2.get_query()
print(q._column_filter == column_name)
print(q.column_filter == column_name)


def test_simple_filter(ds):
Expand Down Expand Up @@ -438,3 +438,21 @@ def test_blob_deserialization(ds):
}
deserialized = DatasourceQuery.deserialize(serialized)
assert queried.get_query().serialize_graphql() == deserialized.serialize_graphql()


def test_sequential_querying(ds):
    """Filtering an already filtered datasource must AND both conditions."""
    for column in ("col1", "col2"):
        add_blob_fields(ds, column)

    first_filter = ds["col1"].is_null()
    both_filters = first_filter["col2"].is_null()

    # Each is_null() filter is expected to serialize with an empty-bytes value.
    col1_is_null = {"isnull": {"data": {"field": "col1", "value": b""}}}
    col2_is_null = {"isnull": {"data": {"field": "col2", "value": b""}}}
    expected = {"and": {"children": [col1_is_null, col2_is_null], "data": None}}

    assert both_filters.get_query().to_dict() == expected

0 comments on commit 1cf23d7

Please sign in to comment.