Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: Can't update inside of a metadata context when iterating over a different query #541

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def to_dict(self, ds: "Datasource") -> Dict[str, Any]:
return res_dict


# Module-level registry of implicit metadata update contexts, keyed by
# ``Datasource.source.id``. Entries are created on demand by
# ``Datasource.implicit_update_context`` and looked up by source id, so every
# Datasource object built on the same source gets the same context.
_metadata_contexts: Dict[Union[int, str], "MetadataContextManager"] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too bad we need a global for this, but I can't think of any other solution.



class Datasource:
def __init__(
self,
Expand All @@ -169,8 +172,6 @@ def __init__(
if query is None:
query = DatasourceQuery()
self._query = query
# a per datasource context used for dict update syntax
self._implicit_update_ctx: Optional[MetadataContextManager] = None
# This reference marks whether the datasource is currently being used
# inside a metadata-update 'with' block
self._explicit_update_ctx: Optional[MetadataContextManager] = None
Expand All @@ -182,12 +183,6 @@ def __init__(
def has_explicit_context(self):
return self._explicit_update_ctx is not None

def _get_source_implicit_metadata_entries(self):
if self._implicit_update_ctx is not None:
return self._implicit_update_ctx.get_metadata_entries()
else:
return []

@property
def source(self) -> "DatasourceState":
return self._source
Expand All @@ -209,9 +204,6 @@ def __deepcopy__(self, memodict={}) -> "Datasource":
res = Datasource(self._source, self._query.__deepcopy__())
res.assigned_dataset = self.assigned_dataset

# Carry over the update context, that way we'll keep track of the metadata being uploaded
res._implicit_update_ctx = self._implicit_update_ctx

return res

def get_query(self) -> "DatasourceQuery":
Expand Down Expand Up @@ -489,21 +481,25 @@ def apply_field_changes(self, field_builders: List[MetadataFieldBuilder]):

@property
def implicit_update_context(self) -> "MetadataContextManager":
if not self._implicit_update_ctx:
self._implicit_update_ctx = MetadataContextManager(self)
"""
Context that is used when updating metadata through ``dp[field] = value`` syntax, can be created on demand.

return self._implicit_update_ctx
:meta private:
"""
key = self.source.id
if key not in _metadata_contexts:
_metadata_contexts[key] = MetadataContextManager(self)
return _metadata_contexts[key]

def upload_metadata_of_implicit_context(self):
"""
Commit the metadata changes made in the dictionary-assignment context.
:meta private:
"""
if self._implicit_update_ctx:
try:
self._upload_metadata(self._get_source_implicit_metadata_entries())
finally:
self._implicit_update_ctx = None
try:
self._upload_metadata(self.implicit_update_context.get_metadata_entries())
finally:
self.implicit_update_context.clear()

def metadata_context(self) -> ContextManager["MetadataContextManager"]:
"""
Expand All @@ -526,9 +522,12 @@ def func():
self._explicit_update_ctx = ctx
yield ctx
try:
self._upload_metadata(ctx.get_metadata_entries() + self._get_source_implicit_metadata_entries())
entries = ctx.get_metadata_entries() + self.implicit_update_context.get_metadata_entries()
self._upload_metadata(entries)
finally:
self._implicit_update_ctx = None
# Clear the implicit context because it can persist
self.implicit_update_context.clear()
# The explicit context created by the `with` block can now be discarded
self._explicit_update_ctx = None

return func()
Expand Down Expand Up @@ -1826,6 +1825,12 @@ def update_metadata(self, datapoints: Union[List[str], str], metadata: Dict[str,
def get_metadata_entries(self):
return self._metadata_entries

def clear(self):
self._metadata_entries.clear()

def __len__(self):
return len(self._metadata_entries)


def _get_datetime_utc_offset(t):
"""
Expand Down
Loading