From 697a93fc57fd84cad6ce6a39c8121d8a257881fe Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Thu, 17 Oct 2024 17:14:06 +0200 Subject: [PATCH 001/124] Initial commit, nuking all metadata responses and seeing what breaks --- src/zenml/client.py | 76 +----- src/zenml/metadata/lazy_load.py | 67 ------ src/zenml/model/model.py | 3 +- src/zenml/models/__init__.py | 12 - src/zenml/models/v2/core/artifact_version.py | 20 +- src/zenml/models/v2/core/model_version.py | 8 +- src/zenml/models/v2/core/pipeline_run.py | 8 +- src/zenml/models/v2/core/run_metadata.py | 219 +----------------- src/zenml/models/v2/core/step_run.py | 8 +- src/zenml/orchestrators/input_utils.py | 5 +- src/zenml/steps/base_step.py | 9 - src/zenml/steps/entrypoint_function_utils.py | 6 +- src/zenml/zen_server/rbac/utils.py | 2 - .../routers/run_metadata_endpoints.py | 96 -------- .../routers/workspaces_endpoints.py | 7 +- src/zenml/zen_server/zen_server_api.py | 2 - src/zenml/zen_stores/rest_zen_store.py | 57 +---- .../zen_stores/schemas/artifact_schemas.py | 2 +- src/zenml/zen_stores/schemas/model_schemas.py | 5 +- .../schemas/pipeline_run_schemas.py | 2 +- .../schemas/run_metadata_schemas.py | 87 ++++--- src/zenml/zen_stores/sql_zen_store.py | 67 +----- src/zenml/zen_stores/zen_store_interface.py | 44 +--- .../functional/models/test_artifact.py | 4 +- 24 files changed, 81 insertions(+), 735 deletions(-) delete mode 100644 src/zenml/metadata/lazy_load.py delete mode 100644 src/zenml/zen_server/routers/run_metadata_endpoints.py diff --git a/src/zenml/client.py b/src/zenml/client.py index dad2f6e123e..5a1f4da469e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -136,9 +136,7 @@ PipelineResponse, PipelineRunFilter, PipelineRunResponse, - RunMetadataFilter, RunMetadataRequest, - RunMetadataResponse, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -4390,7 +4388,7 @@ def create_run_metadata( resource_id: UUID, resource_type: MetadataResourceTypes, stack_component_id: Optional[UUID] = None, - ) -> List[RunMetadataResponse]: + ) -> None: """Create run metadata. Args: @@ -4403,7 +4401,7 @@ def create_run_metadata( the metadata. Returns: - The created metadata, as string to model dictionary. + None """ from zenml.metadata.metadata_types import get_metadata_type @@ -4438,74 +4436,8 @@ def create_run_metadata( values=values, types=types, ) - return self.zen_store.create_run_metadata(run_metadata) - - def list_run_metadata( - self, - sort_by: str = "created", - page: int = PAGINATION_STARTING_PAGE, - size: int = PAGE_SIZE_DEFAULT, - logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - workspace_id: Optional[UUID] = None, - user_id: Optional[UUID] = None, - resource_id: Optional[UUID] = None, - resource_type: Optional[MetadataResourceTypes] = None, - stack_component_id: Optional[UUID] = None, - key: Optional[str] = None, - value: Optional["MetadataType"] = None, - type: Optional[str] = None, - hydrate: bool = False, - ) -> Page[RunMetadataResponse]: - """List run metadata. - - Args: - sort_by: The field to sort the results by. - page: The page number to return. - size: The number of results to return per page. - logical_operator: The logical operator to use for filtering. - id: The ID of the metadata. - created: The creation time of the metadata. - updated: The last update time of the metadata. - workspace_id: The ID of the workspace the metadata belongs to. - user_id: The ID of the user that created the metadata. - resource_id: The ID of the resource the metadata belongs to. - resource_type: The type of the resource the metadata belongs to. - stack_component_id: The ID of the stack component that produced - the metadata. - key: The key of the metadata. - value: The value of the metadata. - type: The type of the metadata. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. - """ - metadata_filter_model = RunMetadataFilter( - sort_by=sort_by, - page=page, - size=size, - logical_operator=logical_operator, - id=id, - created=created, - updated=updated, - workspace_id=workspace_id, - user_id=user_id, - resource_id=resource_id, - resource_type=resource_type, - stack_component_id=stack_component_id, - key=key, - value=value, - type=type, - ) - metadata_filter_model.set_scope_workspace(self.active_workspace.id) - return self.zen_store.list_run_metadata( - metadata_filter_model, - hydrate=hydrate, - ) + self.zen_store.create_run_metadata(run_metadata) + return None # -------------------------------- Secrets --------------------------------- diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py deleted file mode 100644 index 4064450142a..00000000000 --- a/src/zenml/metadata/lazy_load.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Run Metadata Lazy Loader definition.""" - -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from zenml.models import RunMetadataResponse - - -class RunMetadataLazyGetter: - """Run Metadata Lazy Getter helper class. - - It serves the purpose to feed back to the user the metadata - lazy loader wrapper for any given key, if called inside a pipeline - design time context. - """ - - def __init__( - self, - _lazy_load_model_name: str, - _lazy_load_model_version: Optional[str], - _lazy_load_artifact_name: Optional[str] = None, - _lazy_load_artifact_version: Optional[str] = None, - ): - """Initialize a RunMetadataLazyGetter. - - Args: - _lazy_load_model_name: The model name. - _lazy_load_model_version: The model version. - _lazy_load_artifact_name: The artifact name. - _lazy_load_artifact_version: The artifact version. - """ - self._lazy_load_model_name = _lazy_load_model_name - self._lazy_load_model_version = _lazy_load_model_version - self._lazy_load_artifact_name = _lazy_load_artifact_name - self._lazy_load_artifact_version = _lazy_load_artifact_version - - def __getitem__(self, key: str) -> "RunMetadataResponse": - """Get the metadata for the given key. - - Args: - key: The metadata key. - - Returns: - The metadata lazy loader wrapper for the given key. - """ - from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse - - return LazyRunMetadataResponse( - lazy_load_model_name=self._lazy_load_model_name, - lazy_load_model_version=self._lazy_load_model_version, - lazy_load_artifact_name=self._lazy_load_artifact_name, - lazy_load_artifact_version=self._lazy_load_artifact_version, - lazy_load_metadata_name=key, - ) diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index a759a7637b3..8af7cc595c6 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -43,7 +43,6 @@ ModelResponse, ModelVersionResponse, PipelineRunResponse, - RunMetadataResponse, StepRunResponse, ) @@ -350,7 +349,7 @@ def log_metadata( ) @property - def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + def run_metadata(self) -> Dict[str, "MetadataType"]: """Get model version run metadata. Returns: diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index b1d9160b61b..042ed8a1185 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -238,12 +238,7 @@ ) from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse from zenml.models.v2.core.run_metadata import ( - LazyRunMetadataResponse, RunMetadataRequest, - RunMetadataFilter, - RunMetadataResponse, - RunMetadataResponseBody, - RunMetadataResponseMetadata, ) from zenml.models.v2.core.schedule import ( ScheduleRequest, @@ -416,7 +411,6 @@ FlavorResponseBody.model_rebuild() FlavorResponseMetadata.model_rebuild() LazyArtifactVersionResponse.model_rebuild() -LazyRunMetadataResponse.model_rebuild() ModelResponseBody.model_rebuild() ModelResponseMetadata.model_rebuild() ModelVersionResponseBody.model_rebuild() @@ -442,8 +436,6 @@ RunTemplateResponseMetadata.model_rebuild() RunTemplateResponseResources.model_rebuild() RunTemplateResponseBody.model_rebuild() -RunMetadataResponseBody.model_rebuild() -RunMetadataResponseMetadata.model_rebuild() ScheduleResponseBody.model_rebuild() ScheduleResponseMetadata.model_rebuild() SecretResponseBody.model_rebuild() @@ -634,10 +626,6 @@ "RunTemplateResponseResources", "RunTemplateFilter", "RunMetadataRequest", - "RunMetadataFilter", - "RunMetadataResponse", - "RunMetadataResponseBody", - "RunMetadataResponseMetadata", "ScheduleRequest", "ScheduleUpdate", "ScheduleFilter", diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 7a1da75d179..837e7e08c4c 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -30,6 +30,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ArtifactType, GenericFilterOps from zenml.logger import get_logger +from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.base.filter import StrFilter from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, @@ -50,9 +51,6 @@ ArtifactVisualizationResponse, ) from zenml.models.v2.core.pipeline_run import PipelineRunResponse - from zenml.models.v2.core.run_metadata import ( - RunMetadataResponse, - ) from zenml.models.v2.core.step_run import StepRunResponse logger = get_logger(__name__) @@ -193,7 +191,7 @@ class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata): visualizations: Optional[List["ArtifactVisualizationResponse"]] = Field( default=None, title="Visualizations of the artifact." ) - run_metadata: Dict[str, "RunMetadataResponse"] = Field( + run_metadata: Dict[str, MetadataType] = Field( default={}, title="Metadata of the artifact." ) @@ -306,7 +304,7 @@ def visualizations( return self.get_metadata().visualizations @property - def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + def run_metadata(self) -> Dict[str, MetadataType]: """The `metadata` property. Returns: @@ -632,17 +630,11 @@ def get_metadata(self) -> None: # type: ignore[override] ) @property - def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + def run_metadata(self) -> Dict[str, MetadataType]: """The `metadata` property in lazy loading mode. Returns: getter of lazy responses for internal use. """ - from zenml.metadata.lazy_load import RunMetadataLazyGetter - - return RunMetadataLazyGetter( # type: ignore[return-value] - self.lazy_load_model_name, - self.lazy_load_model_version, - self.lazy_load_name, - self.lazy_load_version, - ) + # todo: figure this out + pass diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 817b0fe7353..b25d6a7ffb1 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -29,6 +29,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ModelStages +from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.base.filter import AnyQuery from zenml.models.v2.base.page import Page from zenml.models.v2.base.scoped import ( @@ -49,9 +50,6 @@ from zenml.models.v2.core.artifact_version import ArtifactVersionResponse from zenml.models.v2.core.model import ModelResponse from zenml.models.v2.core.pipeline_run import PipelineRunResponse - from zenml.models.v2.core.run_metadata import ( - RunMetadataResponse, - ) from zenml.zen_stores.schemas import BaseSchema AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -193,7 +191,7 @@ class ModelVersionResponseMetadata(WorkspaceScopedResponseMetadata): max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - run_metadata: Dict[str, "RunMetadataResponse"] = Field( + run_metadata: Dict[str, MetadataType] = Field( description="Metadata linked to the model version", default={}, ) @@ -304,7 +302,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def run_metadata(self) -> Optional[Dict[str, "RunMetadataResponse"]]: + def run_metadata(self) -> Optional[Dict[str, MetadataType]]: """The `run_metadata` property. Returns: diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index cc3f5ad945c..4b824047b1e 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -30,6 +30,7 @@ from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.enums import ExecutionStatus +from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, WorkspaceScopedRequest, @@ -52,9 +53,6 @@ from zenml.models.v2.core.pipeline_build import ( PipelineBuildResponse, ) - from zenml.models.v2.core.run_metadata import ( - RunMetadataResponse, - ) from zenml.models.v2.core.schedule import ScheduleResponse from zenml.models.v2.core.stack import StackResponse from zenml.models.v2.core.step_run import StepRunResponse @@ -191,7 +189,7 @@ class PipelineRunResponseBody(WorkspaceScopedResponseBody): class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata): """Response metadata for pipeline runs.""" - run_metadata: Dict[str, "RunMetadataResponse"] = Field( + run_metadata: Dict[str, MetadataType] = Field( default={}, title="Metadata associated with this pipeline run.", ) @@ -451,7 +449,7 @@ def model_version_id(self) -> Optional[UUID]: return self.get_body().model_version_id @property - def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + def run_metadata(self) -> Dict[str, MetadataType]: """The `run_metadata` property. Returns: diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 99b706a529b..c4a2ef8e678 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,21 +13,15 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Any, Dict, Optional, Union +from typing import Dict, Optional from uuid import UUID -from pydantic import Field, field_validator +from pydantic import Field -from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( - WorkspaceScopedFilter, WorkspaceScopedRequest, - WorkspaceScopedResponse, - WorkspaceScopedResponseBody, - WorkspaceScopedResponseMetadata, - WorkspaceScopedResponseResources, ) # ------------------ Request Model ------------------ @@ -51,212 +45,3 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) - - -# ------------------ Update Model ------------------ - -# There is no update model for run metadata. - -# ------------------ Response Model ------------------ - - -class RunMetadataResponseBody(WorkspaceScopedResponseBody): - """Response body for run metadata.""" - - key: str = Field(title="The key of the metadata.") - value: MetadataType = Field( - title="The value of the metadata.", union_mode="smart" - ) - type: MetadataTypeEnum = Field(title="The type of the metadata.") - - @field_validator("key", "type") - @classmethod - def str_field_max_length_check(cls, value: Any) -> Any: - """Checks if the length of the value exceeds the maximum str length. - - Args: - value: the value set in the field - - Returns: - the value itself. - - Raises: - AssertionError: if the length of the field is longer than the - maximum threshold. - """ - assert len(str(value)) < STR_FIELD_MAX_LENGTH, ( - "The length of the value for this field can not " - f"exceed {STR_FIELD_MAX_LENGTH}" - ) - return value - - @field_validator("value") - @classmethod - def text_field_max_length_check(cls, value: Any) -> Any: - """Checks if the length of the value exceeds the maximum text length. - - Args: - value: the value set in the field - - Returns: - the value itself. - - Raises: - AssertionError: if the length of the field is longer than the - maximum threshold. - """ - assert len(str(value)) < TEXT_FIELD_MAX_LENGTH, ( - "The length of the value for this field can not " - f"exceed {TEXT_FIELD_MAX_LENGTH}" - ) - return value - - -class RunMetadataResponseMetadata(WorkspaceScopedResponseMetadata): - """Response metadata for run metadata.""" - - resource_id: UUID = Field( - title="The ID of the resource that this metadata belongs to.", - ) - resource_type: MetadataResourceTypes = Field( - title="The type of the resource that this metadata belongs to.", - ) - stack_component_id: Optional[UUID] = Field( - title="The ID of the stack component that this metadata belongs to." - ) - - -class RunMetadataResponseResources(WorkspaceScopedResponseResources): - """Class for all resource models associated with the run metadata entity.""" - - -class RunMetadataResponse( - WorkspaceScopedResponse[ - RunMetadataResponseBody, - RunMetadataResponseMetadata, - RunMetadataResponseResources, - ] -): - """Response model for run metadata.""" - - def get_hydrated_version(self) -> "RunMetadataResponse": - """Get the hydrated version of this run metadata. - - Returns: - an instance of the same entity with the metadata field attached. - """ - from zenml.client import Client - - return Client().zen_store.get_run_metadata(self.id) - - # Body and metadata properties - @property - def key(self) -> str: - """The `key` property. - - Returns: - the value of the property. - """ - return self.get_body().key - - @property - def value(self) -> MetadataType: - """The `value` property. - - Returns: - the value of the property. - """ - return self.get_body().value - - @property - def type(self) -> MetadataTypeEnum: - """The `type` property. - - Returns: - the value of the property. - """ - return self.get_body().type - - @property - def resource_id(self) -> UUID: - """The `resource_id` property. - - Returns: - the value of the property. - """ - return self.get_metadata().resource_id - - @property - def resource_type(self) -> MetadataResourceTypes: - """The `resource_type` property. - - Returns: - the value of the property. - """ - return MetadataResourceTypes(self.get_metadata().resource_type) - - @property - def stack_component_id(self) -> Optional[UUID]: - """The `stack_component_id` property. - - Returns: - the value of the property. - """ - return self.get_metadata().stack_component_id - - -# ------------------ Filter Model ------------------ - - -class RunMetadataFilter(WorkspaceScopedFilter): - """Model to enable advanced filtering of run metadata.""" - - resource_id: Optional[Union[str, UUID]] = Field( - default=None, union_mode="left_to_right" - ) - resource_type: Optional[MetadataResourceTypes] = None - stack_component_id: Optional[Union[str, UUID]] = Field( - default=None, union_mode="left_to_right" - ) - key: Optional[str] = None - type: Optional[Union[str, MetadataTypeEnum]] = Field( - default=None, union_mode="left_to_right" - ) - - -# -------------------- Lazy Loader -------------------- - - -class LazyRunMetadataResponse(RunMetadataResponse): - """Lazy run metadata response. - - Used if the run metadata is accessed from the model in - a pipeline context available only during pipeline compilation. - """ - - id: Optional[UUID] = None # type: ignore[assignment] - lazy_load_artifact_name: Optional[str] = None - lazy_load_artifact_version: Optional[str] = None - lazy_load_metadata_name: Optional[str] = None - lazy_load_model_name: str - lazy_load_model_version: Optional[str] = None - - def get_body(self) -> None: # type: ignore[override] - """Protects from misuse of the lazy loader. - - Raises: - RuntimeError: always - """ - raise RuntimeError( - "Cannot access run metadata body before pipeline runs." - ) - - def get_metadata(self) -> None: # type: ignore[override] - """Protects from misuse of the lazy loader. - - Raises: - RuntimeError: always - """ - raise RuntimeError( - "Cannot access run metadata metadata before pipeline runs." - ) diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 8284a19bcc9..662530d1420 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -22,6 +22,7 @@ from zenml.config.step_configurations import StepConfiguration, StepSpec from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ExecutionStatus +from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, WorkspaceScopedRequest, @@ -38,9 +39,6 @@ LogsRequest, LogsResponse, ) - from zenml.models.v2.core.run_metadata import ( - RunMetadataResponse, - ) # ------------------ Request Model ------------------ @@ -230,7 +228,7 @@ class StepRunResponseMetadata(WorkspaceScopedResponseMetadata): title="The IDs of the parent steps of this step run.", default_factory=list, ) - run_metadata: Dict[str, "RunMetadataResponse"] = Field( + run_metadata: Dict[str, MetadataType] = Field( title="Metadata associated with this step run.", default={}, ) @@ -466,7 +464,7 @@ def parent_step_ids(self) -> List[UUID]: return self.get_metadata().parent_step_ids @property - def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + def run_metadata(self) -> Dict[str, MetadataType]: """The `run_metadata` property. Returns: diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index 75127b7d81e..9d3c72d427e 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -19,6 +19,7 @@ from zenml.client import Client from zenml.config.step_configurations import Step from zenml.exceptions import InputResolutionError +from zenml.metadata.metadata_types import MetadataType from zenml.utils import pagination_utils if TYPE_CHECKING: @@ -45,7 +46,7 @@ def resolve_step_inputs( The IDs of the input artifact versions and the IDs of parent steps of the current step. """ - from zenml.models import ArtifactVersionResponse, RunMetadataResponse + from zenml.models import ArtifactVersionResponse current_run_steps = { run_step.name: run_step @@ -142,7 +143,7 @@ def resolve_step_inputs( value_ = cll_.evaluate() if isinstance(value_, ArtifactVersionResponse): input_artifacts[name] = value_ - elif isinstance(value_, RunMetadataResponse): + elif isinstance(value_, MetadataType): step.config.parameters[name] = value_.value else: step.config.parameters[name] = value_ diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 2a3d324608c..efea89a606f 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -330,7 +330,6 @@ def _parse_call_args( from zenml.models.v2.core.artifact_version import ( LazyArtifactVersionResponse, ) - from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse signature = inspect.signature(self.entrypoint, follow_wrapped=True) @@ -378,14 +377,6 @@ def _parse_call_args( artifact_version=value.lazy_load_version, metadata_name=None, ) - elif isinstance(value, LazyRunMetadataResponse): - model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( - model_name=value.lazy_load_model_name, - model_version=value.lazy_load_model_version, - artifact_name=value.lazy_load_artifact_name, - artifact_version=value.lazy_load_artifact_version, - metadata_name=value.lazy_load_metadata_name, - ) elif isinstance(value, ClientLazyLoader): client_lazy_loaders[key] = value else: diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index a91f87131a7..733985d1174 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -136,10 +136,7 @@ def validate_input(self, key: str, value: Any) -> None: UnmaterializedArtifact, ) from zenml.client_lazy_loader import ClientLazyLoader - from zenml.models import ( - ArtifactVersionResponse, - RunMetadataResponse, - ) + from zenml.models import ArtifactVersionResponse if key not in self.inputs: raise KeyError( @@ -154,7 +151,6 @@ def validate_input(self, key: str, value: Any) -> None: StepArtifact, ExternalArtifact, ArtifactVersionResponse, - RunMetadataResponse, ClientLazyLoader, ), ): diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 2dd0b2ef339..9e00e1a740e 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -404,7 +404,6 @@ def get_resource_type_for_model( PipelineDeploymentResponse, PipelineResponse, PipelineRunResponse, - RunMetadataResponse, RunTemplateResponse, SecretResponse, ServiceAccountResponse, @@ -437,7 +436,6 @@ def get_resource_type_for_model( ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION, WorkspaceResponse: ResourceType.WORKSPACE, UserResponse: ResourceType.USER, - RunMetadataResponse: ResourceType.RUN_METADATA, PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT, PipelineBuildResponse: ResourceType.PIPELINE_BUILD, PipelineRunResponse: ResourceType.PIPELINE_RUN, diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py deleted file mode 100644 index c3d97da7a0b..00000000000 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Endpoint definitions for run metadata.""" - -from uuid import UUID - -from fastapi import APIRouter, Depends, Security - -from zenml.constants import API, RUN_METADATA, VERSION_1 -from zenml.models import Page, RunMetadataFilter, RunMetadataResponse -from zenml.zen_server.auth import AuthContext, authorize -from zenml.zen_server.exceptions import error_response -from zenml.zen_server.rbac.endpoint_utils import ( - verify_permissions_and_list_entities, -) -from zenml.zen_server.rbac.models import ResourceType -from zenml.zen_server.utils import ( - handle_exceptions, - make_dependable, - zen_store, -) - -router = APIRouter( - prefix=API + VERSION_1 + RUN_METADATA, - tags=["run_metadata"], - responses={401: error_response, 403: error_response}, -) - - -@router.get( - "", - response_model=Page[RunMetadataResponse], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_run_metadata( - run_metadata_filter_model: RunMetadataFilter = Depends( - make_dependable(RunMetadataFilter) - ), - hydrate: bool = False, - _: AuthContext = Security(authorize), -) -> Page[RunMetadataResponse]: - """Get run metadata according to query filters. - - Args: - run_metadata_filter_model: Filter model used for pagination, sorting, - filtering. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The pipeline runs according to query filters. - """ - return verify_permissions_and_list_entities( - filter_model=run_metadata_filter_model, - resource_type=ResourceType.RUN_METADATA, - list_method=zen_store().list_run_metadata, - hydrate=hydrate, - ) - - -@router.get( - "/{run_metadata_id}", - response_model=RunMetadataResponse, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_run_metadata( - run_metadata_id: UUID, - hydrate: bool = False, - _: AuthContext = Security(authorize), -) -> RunMetadataResponse: - """Get run metadata by ID. - - Args: - run_metadata_id: The ID of run metadata. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata response. - """ - return zen_store().get_run_metadata( - run_metadata_id=run_metadata_id, hydrate=hydrate - ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 3495be8c0d0..6d4e3d038a5 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -74,7 +74,6 @@ PipelineRunRequest, PipelineRunResponse, RunMetadataRequest, - RunMetadataResponse, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -977,7 +976,6 @@ def get_or_create_pipeline_run( @router.post( WORKSPACES + "/{workspace_name_or_id}" + RUN_METADATA, - response_model=List[RunMetadataResponse], responses={401: error_response, 409: error_response, 422: error_response}, ) @handle_exceptions @@ -985,7 +983,7 @@ def create_run_metadata( workspace_name_or_id: Union[str, UUID], run_metadata: RunMetadataRequest, auth_context: AuthContext = Security(authorize), -) -> List[RunMetadataResponse]: +) -> None: """Creates run metadata. Args: @@ -1039,7 +1037,8 @@ def create_run_metadata( resource_type=ResourceType.RUN_METADATA, action=Action.CREATE ) - return zen_store().create_run_metadata(run_metadata) + zen_store().create_run_metadata(run_metadata) + return None @router.post( diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 488d536f25b..d91a7d580ac 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -64,7 +64,6 @@ pipeline_deployments_endpoints, pipelines_endpoints, plugin_endpoints, - run_metadata_endpoints, run_templates_endpoints, runs_endpoints, schedule_endpoints, @@ -420,7 +419,6 @@ async def dashboard(request: Request) -> Any: app.include_router(pipeline_builds_endpoints.router) app.include_router(pipeline_deployments_endpoints.router) app.include_router(runs_endpoints.router) -app.include_router(run_metadata_endpoints.router) app.include_router(run_templates_endpoints.router) app.include_router(schedule_endpoints.router) app.include_router(secrets_endpoints.router) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index d049b55ff3b..0f5fcdbe32f 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -188,9 +188,7 @@ PipelineRunResponse, PipelineRunUpdate, PipelineUpdate, - RunMetadataFilter, RunMetadataRequest, - RunMetadataResponse, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -1939,9 +1937,7 @@ def get_or_create_run( # ----------------------------- Run Metadata ----------------------------- - def create_run_metadata( - self, run_metadata: RunMetadataRequest - ) -> List[RunMetadataResponse]: + def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: """Creates run metadata. Args: @@ -1951,55 +1947,8 @@ def create_run_metadata( The created run metadata. """ route = f"{WORKSPACES}/{str(run_metadata.workspace)}{RUN_METADATA}" - response_body = self.post(f"{route}", body=run_metadata) - result: List[RunMetadataResponse] = [] - if isinstance(response_body, list): - for metadata in response_body or []: - result.append(RunMetadataResponse.model_validate(metadata)) - return result - - def get_run_metadata( - self, run_metadata_id: UUID, hydrate: bool = True - ) -> RunMetadataResponse: - """Gets run metadata with the given ID. - - Args: - run_metadata_id: The ID of the run metadata to get. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. - """ - return self._get_resource( - resource_id=run_metadata_id, - route=RUN_METADATA, - response_model=RunMetadataResponse, - params={"hydrate": hydrate}, - ) - - def list_run_metadata( - self, - run_metadata_filter_model: RunMetadataFilter, - hydrate: bool = False, - ) -> Page[RunMetadataResponse]: - """List run metadata. - - Args: - run_metadata_filter_model: All filter parameters including - pagination params. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. - """ - return self._list_paginated_resources( - route=RUN_METADATA, - response_model=RunMetadataResponse, - filter_model=run_metadata_filter_model, - params={"hydrate": hydrate}, - ) + self.post(f"{route}", body=run_metadata) + return None # ----------------------------- Schedules ----------------------------- diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 2095a82092b..a158f6e5c97 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -351,7 +351,7 @@ def to_model( artifact_store_id=self.artifact_store_id, producer_step_run_id=producer_step_run_id, visualizations=[v.to_model() for v in self.visualizations], - run_metadata={m.key: m.to_model() for m in self.run_metadata}, + run_metadata={m.key: m.value for m in self.run_metadata}, ) resources = None diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index df4a9834ffb..15737ec6613 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -378,10 +378,7 @@ def to_model( metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), description=self.description, - run_metadata={ - rm.key: rm.to_model(include_metadata=True) - for rm in self.run_metadata - }, + run_metadata={rm.key: rm.value for rm in self.run_metadata}, ) resources = None diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 1c84166b318..13b6c8caa92 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -271,7 +271,7 @@ def to_model( ) run_metadata = { - metadata_schema.key: metadata_schema.to_model() + metadata_schema.key: metadata_schema.value for metadata_schema in self.run_metadata } diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index f84e210d97d..4de528abd5a 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -13,20 +13,13 @@ # permissions and limitations under the License. """SQLModel implementation of pipeline run metadata tables.""" -import json -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from sqlalchemy import TEXT, VARCHAR, Column from sqlmodel import Field, Relationship from zenml.enums import MetadataResourceTypes -from zenml.metadata.metadata_types import MetadataTypeEnum -from zenml.models import ( - RunMetadataResponse, - RunMetadataResponseBody, - RunMetadataResponseMetadata, -) from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field @@ -111,42 +104,42 @@ class RunMetadataSchema(BaseSchema, table=True): value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str - def to_model( - self, - include_metadata: bool = False, - include_resources: bool = False, - **kwargs: Any, - ) -> "RunMetadataResponse": - """Convert a `RunMetadataSchema` to a `RunMetadataResponse`. - - Args: - include_metadata: Whether the metadata will be filled. - include_resources: Whether the resources will be filled. - **kwargs: Keyword arguments to allow schema specific logic - - - Returns: - The created `RunMetadataResponse`. - """ - body = RunMetadataResponseBody( - user=self.user.to_model() if self.user else None, - key=self.key, - created=self.created, - updated=self.updated, - value=json.loads(self.value), - type=MetadataTypeEnum(self.type), - ) - metadata = None - if include_metadata: - metadata = RunMetadataResponseMetadata( - workspace=self.workspace.to_model(), - resource_id=self.resource_id, - resource_type=MetadataResourceTypes(self.resource_type), - stack_component_id=self.stack_component_id, - ) - - return RunMetadataResponse( - id=self.id, - body=body, - metadata=metadata, - ) + # def to_model( + # self, + # include_metadata: bool = False, + # include_resources: bool = False, + # **kwargs: Any, + # ) -> "RunMetadataResponse": + # """Convert a `RunMetadataSchema` to a `RunMetadataResponse`. + # + # Args: + # include_metadata: Whether the metadata will be filled. + # include_resources: Whether the resources will be filled. + # **kwargs: Keyword arguments to allow schema specific logic + # + # + # Returns: + # The created `RunMetadataResponse`. + # """ + # body = RunMetadataResponseBody( + # user=self.user.to_model() if self.user else None, + # key=self.key, + # created=self.created, + # updated=self.updated, + # value=json.loads(self.value), + # type=MetadataTypeEnum(self.type), + # ) + # metadata = None + # if include_metadata: + # metadata = RunMetadataResponseMetadata( + # workspace=self.workspace.to_model(), + # resource_id=self.resource_id, + # resource_type=MetadataResourceTypes(self.resource_type), + # stack_component_id=self.stack_component_id, + # ) + # + # return RunMetadataResponse( + # id=self.id, + # body=body, + # metadata=metadata, + # ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 3f95ec278a3..e9c98383a18 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -213,9 +213,7 @@ PipelineRunResponse, PipelineRunUpdate, PipelineUpdate, - RunMetadataFilter, RunMetadataRequest, - RunMetadataResponse, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -5271,9 +5269,7 @@ def count_runs(self, filter_model: Optional[PipelineRunFilter]) -> int: # ----------------------------- Run Metadata ----------------------------- - def create_run_metadata( - self, run_metadata: RunMetadataRequest - ) -> List[RunMetadataResponse]: + def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: """Creates run metadata. Args: @@ -5282,7 +5278,6 @@ def create_run_metadata( Returns: The created run metadata. """ - return_value: List[RunMetadataResponse] = [] with Session(self.engine) as session: for key, value in run_metadata.values.items(): type_ = run_metadata.types[key] @@ -5298,66 +5293,8 @@ def create_run_metadata( ) session.add(run_metadata_schema) session.commit() - return_value.append( - run_metadata_schema.to_model(include_metadata=True) - ) - return return_value - - def get_run_metadata( - self, run_metadata_id: UUID, hydrate: bool = True - ) -> RunMetadataResponse: - """Gets run metadata with the given ID. - Args: - run_metadata_id: The ID of the run metadata to get. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. - - Raises: - KeyError: if the run metadata doesn't exist. - """ - with Session(self.engine) as session: - run_metadata = session.exec( - select(RunMetadataSchema).where( - RunMetadataSchema.id == run_metadata_id - ) - ).first() - if run_metadata is None: - raise KeyError( - f"Unable to get run metadata with ID " - f"{run_metadata_id}: " - f"No run metadata with this ID found." - ) - return run_metadata.to_model(include_metadata=hydrate) - - def list_run_metadata( - self, - run_metadata_filter_model: RunMetadataFilter, - hydrate: bool = False, - ) -> Page[RunMetadataResponse]: - """List run metadata. - - Args: - run_metadata_filter_model: All filter parameters including - pagination params. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. - """ - with Session(self.engine) as session: - query = select(RunMetadataSchema) - return self.filter_and_paginate( - session=session, - query=query, - table=RunMetadataSchema, - filter_model=run_metadata_filter_model, - hydrate=hydrate, - ) + return None # ----------------------------- Schedules ----------------------------- diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 6f0cb7b496d..e4707b96f48 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -90,9 +90,7 @@ PipelineRunResponse, PipelineRunUpdate, PipelineUpdate, - RunMetadataFilter, RunMetadataRequest, - RunMetadataResponse, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -1620,52 +1618,14 @@ def get_or_create_run( # -------------------- Run metadata -------------------- @abstractmethod - def create_run_metadata( - self, run_metadata: RunMetadataRequest - ) -> List[RunMetadataResponse]: + def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: """Creates run metadata. Args: run_metadata: The run metadata to create. Returns: - The created run metadata. - """ - - @abstractmethod - def get_run_metadata( - self, run_metadata_id: UUID, hydrate: bool = True - ) -> RunMetadataResponse: - """Get run metadata by its unique ID. - - Args: - run_metadata_id: The ID of the run metadata to get. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata with the given ID. - - Raises: - KeyError: if the run metadata doesn't exist. - """ - - @abstractmethod - def list_run_metadata( - self, - run_metadata_filter_model: RunMetadataFilter, - hydrate: bool = False, - ) -> Page[RunMetadataResponse]: - """List run metadata. - - Args: - run_metadata_filter_model: All filter parameters including - pagination params. - hydrate: Flag deciding whether to hydrate the output model(s) - by including metadata fields in the response. - - Returns: - The run metadata. + None """ # -------------------- Schedules -------------------- diff --git a/tests/integration/functional/models/test_artifact.py b/tests/integration/functional/models/test_artifact.py index e17351041ce..01e628108a5 100644 --- a/tests/integration/functional/models/test_artifact.py +++ b/tests/integration/functional/models/test_artifact.py @@ -27,10 +27,10 @@ from zenml.artifacts.utils import load_artifact_visualization from zenml.enums import ExecutionStatus from zenml.exceptions import EntityExistsError +from zenml.metadata.metadata_types import MetadataType from zenml.models import ( ArtifactVersionResponse, ArtifactVisualizationResponse, - RunMetadataResponse, ) if TYPE_CHECKING: @@ -341,7 +341,7 @@ def _get_visualizations_of_last_run( def _get_metadata_of_last_run( clean_client: "Client", -) -> Dict[str, "RunMetadataResponse"]: +) -> Dict[str, MetadataType]: """Get the artifact metadata of the last run.""" return _get_output_of_last_run(clean_client).run_metadata From 733a6c8b9f7bcd3becbfece2a41e3d7f51be1627 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Thu, 17 Oct 2024 17:55:39 +0200 Subject: [PATCH 002/124] Removed last remnant of LazyLoader --- src/zenml/model/model.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 8af7cc595c6..7332ac7ac5c 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -358,18 +358,6 @@ def run_metadata(self) -> Dict[str, "MetadataType"]: Raises: RuntimeError: If the model version run metadata cannot be fetched. """ - from zenml.metadata.lazy_load import RunMetadataLazyGetter - - try: - get_pipeline_context() - # avoid exposing too much of internal details by keeping the return type - return RunMetadataLazyGetter( # type: ignore[return-value] - self.name, - self._lazy_version, - ) - except RuntimeError: - pass - response = self._get_or_create_model_version(hydrate=True) if response.run_metadata is None: raise RuntimeError( From ae71757cb248a8667004b2547dca6604dc94e0aa Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Fri, 18 Oct 2024 12:02:18 +0200 Subject: [PATCH 003/124] Reintroducing the lazy loaders. --- src/zenml/metadata/lazy_load.py | 67 ++++++++++++++++++++ src/zenml/model/model.py | 12 ++++ src/zenml/models/__init__.py | 2 + src/zenml/models/v2/core/artifact_version.py | 10 ++- src/zenml/models/v2/core/run_metadata.py | 37 ++++++++++- src/zenml/orchestrators/input_utils.py | 8 +-- src/zenml/steps/base_step.py | 9 +++ tests/integration/functional/test_client.py | 62 +++++++++++++----- 8 files changed, 182 insertions(+), 25 deletions(-) create mode 100644 src/zenml/metadata/lazy_load.py diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py new file mode 100644 index 00000000000..4064450142a --- /dev/null +++ b/src/zenml/metadata/lazy_load.py @@ -0,0 +1,67 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Run Metadata Lazy Loader definition.""" + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from zenml.models import RunMetadataResponse + + +class RunMetadataLazyGetter: + """Run Metadata Lazy Getter helper class. + + It serves the purpose to feed back to the user the metadata + lazy loader wrapper for any given key, if called inside a pipeline + design time context. + """ + + def __init__( + self, + _lazy_load_model_name: str, + _lazy_load_model_version: Optional[str], + _lazy_load_artifact_name: Optional[str] = None, + _lazy_load_artifact_version: Optional[str] = None, + ): + """Initialize a RunMetadataLazyGetter. + + Args: + _lazy_load_model_name: The model name. + _lazy_load_model_version: The model version. + _lazy_load_artifact_name: The artifact name. + _lazy_load_artifact_version: The artifact version. + """ + self._lazy_load_model_name = _lazy_load_model_name + self._lazy_load_model_version = _lazy_load_model_version + self._lazy_load_artifact_name = _lazy_load_artifact_name + self._lazy_load_artifact_version = _lazy_load_artifact_version + + def __getitem__(self, key: str) -> "RunMetadataResponse": + """Get the metadata for the given key. + + Args: + key: The metadata key. + + Returns: + The metadata lazy loader wrapper for the given key. + """ + from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse + + return LazyRunMetadataResponse( + lazy_load_model_name=self._lazy_load_model_name, + lazy_load_model_version=self._lazy_load_model_version, + lazy_load_artifact_name=self._lazy_load_artifact_name, + lazy_load_artifact_version=self._lazy_load_artifact_version, + lazy_load_metadata_name=key, + ) diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 7332ac7ac5c..8af7cc595c6 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -358,6 +358,18 @@ def run_metadata(self) -> Dict[str, "MetadataType"]: Raises: RuntimeError: If the model version run metadata cannot be fetched. """ + from zenml.metadata.lazy_load import RunMetadataLazyGetter + + try: + get_pipeline_context() + # avoid exposing too much of internal details by keeping the return type + return RunMetadataLazyGetter( # type: ignore[return-value] + self.name, + self._lazy_version, + ) + except RuntimeError: + pass + response = self._get_or_create_model_version(hydrate=True) if response.run_metadata is None: raise RuntimeError( diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 042ed8a1185..a330565c7bf 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -238,6 +238,7 @@ ) from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse from zenml.models.v2.core.run_metadata import ( + LazyRunMetadataResponse, RunMetadataRequest, ) from zenml.models.v2.core.schedule import ( @@ -411,6 +412,7 @@ FlavorResponseBody.model_rebuild() FlavorResponseMetadata.model_rebuild() LazyArtifactVersionResponse.model_rebuild() +LazyRunMetadataResponse.model_rebuild() ModelResponseBody.model_rebuild() ModelResponseMetadata.model_rebuild() ModelVersionResponseBody.model_rebuild() diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 837e7e08c4c..45ac2432117 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -636,5 +636,11 @@ def run_metadata(self) -> Dict[str, MetadataType]: Returns: getter of lazy responses for internal use. """ - # todo: figure this out - pass + from zenml.metadata.lazy_load import RunMetadataLazyGetter + + return RunMetadataLazyGetter( # type: ignore[return-value] + self.lazy_load_model_name, + self.lazy_load_model_version, + self.lazy_load_name, + self.lazy_load_version, + ) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index c4a2ef8e678..7fd68d4a713 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -21,7 +21,7 @@ from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( - WorkspaceScopedRequest, + WorkspaceScopedRequest, WorkspaceScopedResponse ) # ------------------ Request Model ------------------ @@ -45,3 +45,38 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) + + +class LazyRunMetadataResponse(WorkspaceScopedResponse): + """Lazy run metadata response. + + Used if the run metadata is accessed from the model in + a pipeline context available only during pipeline compilation. + """ + + id: Optional[UUID] = None # type: ignore[assignment] + lazy_load_artifact_name: Optional[str] = None + lazy_load_artifact_version: Optional[str] = None + lazy_load_metadata_name: Optional[str] = None + lazy_load_model_name: str + lazy_load_model_version: Optional[str] = None + + def get_body(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError( + "Cannot access run metadata body before pipeline runs." + ) + + def get_metadata(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError( + "Cannot access run metadata metadata before pipeline runs." + ) diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index 9d3c72d427e..0d596f7e498 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utilities for inputs.""" -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple, get_args, Union from uuid import UUID from zenml.client import Client @@ -101,7 +101,7 @@ def resolve_step_inputs( step.config.parameters[name] = ( context_model_version.run_metadata[ config_.metadata_name - ].value + ] ) elif config_.artifact_name is None: err_msg = ( @@ -120,7 +120,7 @@ def resolve_step_inputs( step.config.parameters[name] = ( artifact_.run_metadata[ config_.metadata_name - ].value + ] ) except KeyError: err_msg = ( @@ -143,8 +143,6 @@ def resolve_step_inputs( value_ = cll_.evaluate() if isinstance(value_, ArtifactVersionResponse): input_artifacts[name] = value_ - elif isinstance(value_, MetadataType): - step.config.parameters[name] = value_.value else: step.config.parameters[name] = value_ diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index efea89a606f..2a3d324608c 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -330,6 +330,7 @@ def _parse_call_args( from zenml.models.v2.core.artifact_version import ( LazyArtifactVersionResponse, ) + from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse signature = inspect.signature(self.entrypoint, follow_wrapped=True) @@ -377,6 +378,14 @@ def _parse_call_args( artifact_version=value.lazy_load_version, metadata_name=None, ) + elif isinstance(value, LazyRunMetadataResponse): + model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( + model_name=value.lazy_load_model_name, + model_version=value.lazy_load_model_version, + artifact_name=value.lazy_load_artifact_name, + artifact_version=value.lazy_load_artifact_version, + metadata_name=value.lazy_load_metadata_name, + ) elif isinstance(value, ClientLazyLoader): client_lazy_loaders[key] = value else: diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 0bd4523d796..5c05b40b298 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -20,6 +20,7 @@ from uuid import uuid4 import pytest +from mypy.types import names from pydantic import BaseModel from typing_extensions import Annotated @@ -33,7 +34,7 @@ log_artifact_metadata, pipeline, save_artifact, - step, + step, get_pipeline_context, log_model_metadata, get_step_context, ) from zenml.client import Client from zenml.config.source import Source @@ -1095,9 +1096,34 @@ def test_basic_crud_for_entity( # This means the test already succeeded and deleted the entity, # nothing to do here pass - - -@step +# +# +# @step +# def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: +# """Produce artifact with metadata.""" +# from zenml.client import Client +# +# log_artifact_metadata(metadata={"some_meta": "meta_new_one"}) +# +# client = Client() +# model = client.create_model(name="model_name", description="model_desc") +# client.create_model_version( +# model_name_or_id=model.id, +# name="model_version", +# description="mv_desc_1", +# ) +# mv = client.create_model_version( +# model_name_or_id=model.id, +# name="model_version2", +# description="mv_desc_2", +# ) +# client.update_model_version( +# model_name_or_id=model.id, version_name_or_id=mv.id, stage="staging" +# ) +# return "body_new_one" + + +@step() def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: """Produce artifact with metadata.""" from zenml.client import Client @@ -1105,19 +1131,14 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: log_artifact_metadata(metadata={"some_meta": "meta_new_one"}) client = Client() - model = client.create_model(name="model_name", description="model_desc") - client.create_model_version( - model_name_or_id=model.id, - name="model_version", - description="mv_desc_1", - ) - mv = client.create_model_version( - model_name_or_id=model.id, - name="model_version2", - description="mv_desc_2", + + log_model_metadata( + metadata={"some_meta": "meta_new_one"}, ) + model = get_step_context().model + client.update_model_version( - model_name_or_id=model.id, version_name_or_id=mv.id, stage="staging" + model_name_or_id=model.name, version_name_or_id=model.version, stage="staging" ) return "body_new_one" @@ -1206,7 +1227,7 @@ def test_pipeline_can_load_in_lazy_mode( ): """Tests that user can load model artifact versions, metadata and models (versions) in lazy mode in pipeline codes.""" - @pipeline(enable_cache=False) + @pipeline(enable_cache=False, model=Model(name="aria", version="new")) def dummy(): artifact_existing = clean_client.get_artifact_version( name_id_or_prefix="preexisting" @@ -1222,6 +1243,8 @@ def dummy(): model = clean_client.get_model(model_name_or_id="model_name") + lz2 = get_pipeline_context().model.run_metadata["some_meta"] + lazy_producer_test_artifact() lazy_asserter_test_artifact( # load artifact directly @@ -1231,7 +1254,7 @@ def dummy(): # pass as artifact response artifact_new, # read value of metadata directly - artifact_metadata_new.value, + artifact_metadata_new, # load model model, # load model version by version @@ -1257,6 +1280,11 @@ def dummy(): artifact_name="preexisting", artifact_version="1.2.3", ) + log_model_metadata( + metadata={"some_meta": "meta_preexisting"}, + model_name="aria", + model_version="new" + ) with pytest.raises(KeyError): clean_client.get_artifact_version("new_one") dummy() From 7d0ff8246c88a23566f8dc3714ce53803afce80b Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:50:30 +0200 Subject: [PATCH 004/124] Add LazyRunMetadataResponse to EntrypointFunctionDefinition --- src/zenml/steps/entrypoint_function_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 733985d1174..997d639d8ee 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -32,6 +32,7 @@ from zenml.exceptions import StepInterfaceError from zenml.logger import get_logger from zenml.materializers.base_materializer import BaseMaterializer +from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse from zenml.steps.utils import ( OutputSignature, parse_return_type_annotations, @@ -152,6 +153,7 @@ def validate_input(self, key: str, value: Any) -> None: ExternalArtifact, ArtifactVersionResponse, ClientLazyLoader, + LazyRunMetadataResponse, ), ): # If we were to do any type validation for artifacts here, we From d7a9f8396830f475d31aeb4983e2b7aaf5d66537 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Fri, 18 Oct 2024 14:05:34 +0200 Subject: [PATCH 005/124] Test for lazy loaders works now --- src/zenml/cli/model.py | 2 +- src/zenml/models/v2/core/run_metadata.py | 3 +- src/zenml/orchestrators/input_utils.py | 11 +-- .../zen_stores/schemas/artifact_schemas.py | 5 +- src/zenml/zen_stores/schemas/model_schemas.py | 5 +- .../schemas/pipeline_run_schemas.py | 2 +- tests/integration/functional/test_client.py | 71 +++++++++---------- 7 files changed, 47 insertions(+), 52 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index ab9355abe76..f03de321144 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -62,7 +62,7 @@ def _model_version_to_print( run_metadata = None if model_version.run_metadata: run_metadata = { - k: v.value for k, v in model_version.run_metadata.items() + k: v for k, v in model_version.run_metadata.items() } return { "id": model_version.id, diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 7fd68d4a713..366f12edbb1 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -21,7 +21,8 @@ from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( - WorkspaceScopedRequest, WorkspaceScopedResponse + WorkspaceScopedRequest, + WorkspaceScopedResponse, ) # ------------------ Request Model ------------------ diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index 0d596f7e498..fb30f68a1a8 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -13,13 +13,12 @@ # permissions and limitations under the License. """Utilities for inputs.""" -from typing import TYPE_CHECKING, Dict, List, Tuple, get_args, Union +from typing import TYPE_CHECKING, Dict, List, Tuple from uuid import UUID from zenml.client import Client from zenml.config.step_configurations import Step from zenml.exceptions import InputResolutionError -from zenml.metadata.metadata_types import MetadataType from zenml.utils import pagination_utils if TYPE_CHECKING: @@ -99,9 +98,7 @@ def resolve_step_inputs( ): # metadata values should go directly in parameters, as primitive types step.config.parameters[name] = ( - context_model_version.run_metadata[ - config_.metadata_name - ] + context_model_version.run_metadata[config_.metadata_name] ) elif config_.artifact_name is None: err_msg = ( @@ -118,9 +115,7 @@ def resolve_step_inputs( # metadata values should go directly in parameters, as primitive types try: step.config.parameters[name] = ( - artifact_.run_metadata[ - config_.metadata_name - ] + artifact_.run_metadata[config_.metadata_name] ) except KeyError: err_msg = ( diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index a158f6e5c97..8efb48aced4 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """SQLModel implementation of artifact table.""" +import json from datetime import datetime from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID @@ -351,7 +352,9 @@ def to_model( artifact_store_id=self.artifact_store_id, producer_step_run_id=producer_step_run_id, visualizations=[v.to_model() for v in self.visualizations], - run_metadata={m.key: m.value for m in self.run_metadata}, + run_metadata={ + m.key: json.loads(m.value) for m in self.run_metadata + }, ) resources = None diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 15737ec6613..d2b09548f93 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """SQLModel implementation of model tables.""" +import json from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import UUID @@ -378,7 +379,9 @@ def to_model( metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), description=self.description, - run_metadata={rm.key: rm.value for rm in self.run_metadata}, + run_metadata={ + rm.key: json.loads(rm.value) for rm in self.run_metadata + }, ) resources = None diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 13b6c8caa92..66312f4a351 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -271,7 +271,7 @@ def to_model( ) run_metadata = { - metadata_schema.key: metadata_schema.value + metadata_schema.key: json.loads(metadata_schema.value) for metadata_schema in self.run_metadata } diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 5c05b40b298..23983031e40 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -20,7 +20,6 @@ from uuid import uuid4 import pytest -from mypy.types import names from pydantic import BaseModel from typing_extensions import Annotated @@ -31,10 +30,13 @@ from tests.integration.functional.utils import sample_name from zenml import ( ExternalArtifact, + get_pipeline_context, + get_step_context, log_artifact_metadata, + log_model_metadata, pipeline, save_artifact, - step, get_pipeline_context, log_model_metadata, get_step_context, + step, ) from zenml.client import Client from zenml.config.source import Source @@ -1096,34 +1098,9 @@ def test_basic_crud_for_entity( # This means the test already succeeded and deleted the entity, # nothing to do here pass -# -# -# @step -# def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: -# """Produce artifact with metadata.""" -# from zenml.client import Client -# -# log_artifact_metadata(metadata={"some_meta": "meta_new_one"}) -# -# client = Client() -# model = client.create_model(name="model_name", description="model_desc") -# client.create_model_version( -# model_name_or_id=model.id, -# name="model_version", -# description="mv_desc_1", -# ) -# mv = client.create_model_version( -# model_name_or_id=model.id, -# name="model_version2", -# description="mv_desc_2", -# ) -# client.update_model_version( -# model_name_or_id=model.id, version_name_or_id=mv.id, stage="staging" -# ) -# return "body_new_one" - - -@step() + + +@step def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: """Produce artifact with metadata.""" from zenml.client import Client @@ -1135,10 +1112,16 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: log_model_metadata( metadata={"some_meta": "meta_new_one"}, ) + model = get_step_context().model + mv = client.create_model_version( + model_name_or_id=model.name, + name="model_version2", + description="mv_desc_2", + ) client.update_model_version( - model_name_or_id=model.name, version_name_or_id=model.version, stage="staging" + model_name_or_id=model.name, version_name_or_id=mv.id, stage="staging" ) return "body_new_one" @@ -1152,6 +1135,7 @@ def lazy_asserter_test_artifact( model: ModelResponse, model_version_by_version: ModelVersionResponse, model_version_by_stage: ModelVersionResponse, + model_version_run_metadata: str, ): """Assert that passed in values are loaded in lazy mode. They do not exists before actual run of the pipeline. @@ -1161,12 +1145,13 @@ def lazy_asserter_test_artifact( assert artifact_new == "body_new_one" assert artifact_metadata_new == "meta_new_one" - assert model.name == "model_name" - assert model.description == "model_desc" + assert model.name == "aria" + # assert model.description == "model_description" assert model_version_by_version.name == "model_version" - assert model_version_by_version.description == "mv_desc_1" + # assert model_version_by_version.description == "mv_desc_1" assert model_version_by_stage.name == "model_version2" assert model_version_by_stage.description == "mv_desc_2" + assert model_version_run_metadata == "meta_new_one" class TestArtifact: @@ -1227,7 +1212,12 @@ def test_pipeline_can_load_in_lazy_mode( ): """Tests that user can load model artifact versions, metadata and models (versions) in lazy mode in pipeline codes.""" - @pipeline(enable_cache=False, model=Model(name="aria", version="new")) + @pipeline( + enable_cache=False, + model=Model( + name="aria", version="model_version", description="mv_desc_1" + ), + ) def dummy(): artifact_existing = clean_client.get_artifact_version( name_id_or_prefix="preexisting" @@ -1241,9 +1231,11 @@ def dummy(): ) artifact_metadata_new = artifact_new.run_metadata["some_meta"] - model = clean_client.get_model(model_name_or_id="model_name") + model = clean_client.get_model(model_name_or_id="aria") - lz2 = get_pipeline_context().model.run_metadata["some_meta"] + model_version_run_metadata = ( + get_pipeline_context().model.run_metadata["some_meta"] + ) lazy_producer_test_artifact() lazy_asserter_test_artifact( @@ -1266,9 +1258,10 @@ def dummy(): # load model version by stage clean_client.get_model_version( # this can be lazy loaders too - model.id, + model_name_or_id=model.id, model_version_name_or_number_or_id="staging", ), + model_version_run_metadata, after=["lazy_producer_test_artifact"], ) @@ -1283,7 +1276,7 @@ def dummy(): log_model_metadata( metadata={"some_meta": "meta_preexisting"}, model_name="aria", - model_version="new" + model_version="model_version", ) with pytest.raises(KeyError): clean_client.get_artifact_version("new_one") From 9a0e0b29ca4e2ec97c30a7816e618cdb2505a350 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 08:38:01 +0200 Subject: [PATCH 006/124] Fixed tests, reformatted --- src/zenml/cli/model.py | 4 +--- .../functional/artifacts/test_utils.py | 24 +++++++++---------- .../functional/model/test_model_version.py | 18 +++++++------- .../functional/steps/test_step_context.py | 6 ++--- .../functional/steps/test_utils.py | 12 +++++----- 5 files changed, 31 insertions(+), 33 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index f03de321144..d8ea4534bfd 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -61,9 +61,7 @@ def _model_version_to_print( ) -> Dict[str, Any]: run_metadata = None if model_version.run_metadata: - run_metadata = { - k: v for k, v in model_version.run_metadata.items() - } + run_metadata = {k: v for k, v in model_version.run_metadata.items()} return { "id": model_version.id, "model": model_version.model.name, diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index a084f935fbb..78d67d3e6da 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -150,22 +150,22 @@ def test_log_artifact_metadata_existing(clean_client): "meaning_of_life", version="1" ) assert "description" in artifact_1.run_metadata - assert artifact_1.run_metadata["description"].value == "Aria is great!" + assert artifact_1.run_metadata["description"] == "Aria is great!" assert "description_3" in artifact_1.run_metadata - assert artifact_1.run_metadata["description_3"].value == "Axl is great!" + assert artifact_1.run_metadata["description_3"] == "Axl is great!" assert "float" in artifact_1.run_metadata - assert artifact_1.run_metadata["float"].value - 1.0 < 10e-6 + assert artifact_1.run_metadata["float"] - 1.0 < 10e-6 assert "int" in artifact_1.run_metadata - assert artifact_1.run_metadata["int"].value == 1 + assert artifact_1.run_metadata["int"] == 1 assert "str" in artifact_1.run_metadata - assert artifact_1.run_metadata["str"].value == "1.0" + assert artifact_1.run_metadata["str"] == "1.0" assert "list_str" in artifact_1.run_metadata assert ( - len(set(artifact_1.run_metadata["list_str"].value) - {"1.0", "2.0"}) + len(set(artifact_1.run_metadata["list_str"]) - {"1.0", "2.0"}) == 0 ) assert "list_floats" in artifact_1.run_metadata - for each in artifact_1.run_metadata["list_floats"].value: + for each in artifact_1.run_metadata["list_floats"]: if 0.99 < each < 1.01: assert each - 1.0 < 10e-6 else: @@ -175,7 +175,7 @@ def test_log_artifact_metadata_existing(clean_client): "meaning_of_life", version="43" ) assert "description_2" in artifact_2.run_metadata - assert artifact_2.run_metadata["description_2"].value == "Blupus is great!" + assert artifact_2.run_metadata["description_2"] == "Blupus is great!" @step @@ -200,9 +200,9 @@ def artifact_metadata_logging_pipeline(): run_ = artifact_metadata_logging_pipeline.model.last_run output = run_.steps["artifact_metadata_logging_step"].output assert "description" in output.run_metadata - assert output.run_metadata["description"].value == "Aria is great!" + assert output.run_metadata["description"] == "Aria is great!" assert "metrics" in output.run_metadata - assert output.run_metadata["metrics"].value == {"accuracy": 0.9} + assert output.run_metadata["metrics"] == {"accuracy": 0.9} @step @@ -233,9 +233,9 @@ def artifact_metadata_logging_pipeline(): assert "metrics" not in str_output.run_metadata int_output = step_.outputs["int_output"] assert "description" in int_output.run_metadata - assert int_output.run_metadata["description"].value == "Blupus is great!" + assert int_output.run_metadata["description"] == "Blupus is great!" assert "metrics" in int_output.run_metadata - assert int_output.run_metadata["metrics"].value == {"accuracy": 0.9} + assert int_output.run_metadata["metrics"] == {"accuracy": 0.9} @step diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index b9a6b3c7b54..43ae5d4ca21 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -108,7 +108,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def step_metadata_logging_functional(mdl_name: str): """Functional logging using implicit Model from context.""" log_model_metadata({"foo": "bar"}) - assert get_step_context().model.run_metadata["foo"].value == "bar" + assert get_step_context().model.run_metadata["foo"] == "bar" log_model_metadata( {"foo": "bar"}, model_name=mdl_name, model_version="other" ) @@ -393,13 +393,13 @@ def test_metadata_logging(self): mv.log_metadata({"foo": "bar"}) assert len(mv.run_metadata) == 1 - assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["foo"] == "bar" mv.log_metadata({"bar": "foo"}) assert len(mv.run_metadata) == 2 - assert mv.run_metadata["foo"].value == "bar" - assert mv.run_metadata["bar"].value == "foo" + assert mv.run_metadata["foo"] == "bar" + assert mv.run_metadata["bar"] == "foo" def test_metadata_logging_functional(self): """Test that model version can be used to track metadata from function.""" @@ -415,7 +415,7 @@ def test_metadata_logging_functional(self): ) assert len(mv.run_metadata) == 1 - assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["foo"] == "bar" with pytest.raises(ValueError): log_model_metadata({"foo": "bar"}) @@ -425,8 +425,8 @@ def test_metadata_logging_functional(self): ) assert len(mv.run_metadata) == 2 - assert mv.run_metadata["foo"].value == "bar" - assert mv.run_metadata["bar"].value == "foo" + assert mv.run_metadata["foo"] == "bar" + assert mv.run_metadata["bar"] == "foo" def test_metadata_logging_in_steps(self): """Test that model version can be used to track metadata from function in steps.""" @@ -449,11 +449,11 @@ def my_pipeline(): mv = Model(name=mdl_name, version="context") assert len(mv.run_metadata) == 1 - assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["foo"] == "bar" mv = Model(name=mdl_name, version="other") assert len(mv.run_metadata) == 1 - assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["foo"] == "bar" @pytest.mark.parametrize("delete_artifacts", [False, True]) def test_deletion_of_links(self, delete_artifacts: bool): diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index 8ce7b447cc5..ad1352f0381 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -100,7 +100,7 @@ def output_metadata_logging_step() -> Annotated[int, "my_output"]: def step_context_metadata_reader_step(my_input: int) -> None: step_context = get_step_context() my_input_metadata = step_context.inputs["my_input"].run_metadata - assert my_input_metadata["some_key"].value == "some_value" + assert my_input_metadata["some_key"] == "some_value" def test_input_artifacts_property(): @@ -205,10 +205,10 @@ def _pipeline(): artifact = clean_client.get_artifact(full_name) for k, v in metadata.items(): assert k in av.run_metadata - assert av.run_metadata[k].value == v + assert av.run_metadata[k] == v if full_name == "custom_name": - assert av.run_metadata["config_metadata"].value == "bar" + assert av.run_metadata["config_metadata"] == "bar" assert {t.name for t in av.tags} == set(tags).union({"config_tags"}) assert {t.name for t in artifact.tags} == set(tags).union( {"config_tags"} diff --git a/tests/integration/functional/steps/test_utils.py b/tests/integration/functional/steps/test_utils.py index 539ff75520e..7bdff4867e9 100644 --- a/tests/integration/functional/steps/test_utils.py +++ b/tests/integration/functional/steps/test_utils.py @@ -49,9 +49,9 @@ def step_metadata_logging_pipeline(): "step_metadata_logging_step_inside_run" ].run_metadata assert "description" in run_metadata - assert run_metadata["description"].value == "Aria is great!" + assert run_metadata["description"] == "Aria is great!" assert "metrics" in run_metadata - assert run_metadata["metrics"].value == {"accuracy": 0.9} + assert run_metadata["metrics"] == {"accuracy": 0.9} def test_log_step_metadata_using_latest_run(clean_client): @@ -84,9 +84,9 @@ def step_metadata_logging_pipeline(): "step_metadata_logging_step" ].run_metadata assert "description" in run_metadata_after_log - assert run_metadata_after_log["description"].value == "Axl is great!" + assert run_metadata_after_log["description"] == "Axl is great!" assert "metrics" in run_metadata_after_log - assert run_metadata_after_log["metrics"].value == {"accuracy": 0.9} + assert run_metadata_after_log["metrics"] == {"accuracy": 0.9} def test_log_step_metadata_using_specific_params(clean_client): @@ -124,6 +124,6 @@ def step_metadata_logging_pipeline(): "step_metadata_logging_step" ].run_metadata assert "description" in run_metadata_after_log - assert run_metadata_after_log["description"].value == "Blupus is great!" + assert run_metadata_after_log["description"] == "Blupus is great!" assert "metrics" in run_metadata_after_log - assert run_metadata_after_log["metrics"].value == {"accuracy": 0.9} + assert run_metadata_after_log["metrics"] == {"accuracy": 0.9} From 145b90b81ba9b0ce67a90e6c2cff6dac267ce4d3 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 09:47:59 +0200 Subject: [PATCH 007/124] Use updated template --- .github/workflows/update-templates-to-examples.yml | 2 +- src/zenml/cli/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index 9b55d1c9928..db47166a00e 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -189,7 +189,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.09.24 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.10.21 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index ecbb01d3607..98fd902cc78 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -83,7 +83,7 @@ def copier_github_url(self) -> str: ), starter=ZenMLProjectTemplateLocation( github_url="zenml-io/template-starter", - github_tag="2024.09.24", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.10.21", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), nlp=ZenMLProjectTemplateLocation( github_url="zenml-io/template-nlp", From 1e1991ad83700557f6cefa759e846c264d0056ba Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 21 Oct 2024 07:50:39 +0000 Subject: [PATCH 008/124] Auto-update of Starter template --- examples/mlops_starter/.copier-answers.yml | 2 +- examples/mlops_starter/quickstart.ipynb | 4 ++-- examples/mlops_starter/run.py | 4 ++-- examples/mlops_starter/steps/model_promoter.py | 8 +++----- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index 8b1fb8187ed..e17f27ee551 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.09.24 +_commit: 2024.10.21 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/examples/mlops_starter/quickstart.ipynb b/examples/mlops_starter/quickstart.ipynb index df8c010b5ea..6fba7a0e8cc 100644 --- a/examples/mlops_starter/quickstart.ipynb +++ b/examples/mlops_starter/quickstart.ipynb @@ -994,8 +994,8 @@ "@pipeline\n", "def inference(preprocess_pipeline_id: UUID):\n", " \"\"\"Model batch inference pipeline\"\"\"\n", - " # random_state = client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id).metadata[\"random_state\"].value\n", - " # target = client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id).run_metadata['target'].value\n", + " # random_state = client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id).metadata[\"random_state\"]\n", + " # target = client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id).run_metadata['target']\n", " random_state = 42\n", " target = \"target\"\n", "\n", diff --git a/examples/mlops_starter/run.py b/examples/mlops_starter/run.py index d7b1a7f11b2..16a352588d6 100644 --- a/examples/mlops_starter/run.py +++ b/examples/mlops_starter/run.py @@ -239,8 +239,8 @@ def main( # to get the random state and target column random_state = preprocess_pipeline_artifact.run_metadata[ "random_state" - ].value - target = preprocess_pipeline_artifact.run_metadata["target"].value + ] + target = preprocess_pipeline_artifact.run_metadata["target"] run_args_inference["random_state"] = random_state run_args_inference["target"] = target diff --git a/examples/mlops_starter/steps/model_promoter.py b/examples/mlops_starter/steps/model_promoter.py index 52040638496..43d43ceac1f 100644 --- a/examples/mlops_starter/steps/model_promoter.py +++ b/examples/mlops_starter/steps/model_promoter.py @@ -58,11 +58,9 @@ def model_promoter(accuracy: float, stage: str = "production") -> bool: try: stage_model = client.get_model_version(current_model.name, stage) # We compare their metrics - prod_accuracy = ( - stage_model.get_artifact("sklearn_classifier") - .run_metadata["test_accuracy"] - .value - ) + prod_accuracy = stage_model.get_artifact( + "sklearn_classifier" + ).run_metadata["test_accuracy"] if float(accuracy) > float(prod_accuracy): # If current model has better metrics, we promote it is_promoted = True From d83628aeab33fe5cc9375578b36079e461ef47d7 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 10:05:41 +0200 Subject: [PATCH 009/124] Updated more templates --- .github/workflows/update-templates-to-examples.yml | 4 ++-- src/zenml/cli/base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index db47166a00e..327e2d45934 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -46,7 +46,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.10 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.10.21 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -118,7 +118,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.09.23 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.10.21 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 98fd902cc78..586b4081c2c 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -79,7 +79,7 @@ def copier_github_url(self) -> str: ZENML_PROJECT_TEMPLATES = dict( e2e_batch=ZenMLProjectTemplateLocation( github_url="zenml-io/template-e2e-batch", - github_tag="2024.10.10", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.10.21", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), starter=ZenMLProjectTemplateLocation( github_url="zenml-io/template-starter", @@ -87,7 +87,7 @@ def copier_github_url(self) -> str: ), nlp=ZenMLProjectTemplateLocation( github_url="zenml-io/template-nlp", - github_tag="2024.09.23", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.10.21", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), llm_finetuning=ZenMLProjectTemplateLocation( github_url="zenml-io/template-llm-finetuning", From c4febf312192fb01e3360cea23fbcc9dd7823f58 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 10:32:56 +0200 Subject: [PATCH 010/124] Fixed failing test --- .../functional/artifacts/test_utils.py | 5 +--- .../functional/zen_stores/test_zen_store.py | 23 +++++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 78d67d3e6da..c2319c16233 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -160,10 +160,7 @@ def test_log_artifact_metadata_existing(clean_client): assert "str" in artifact_1.run_metadata assert artifact_1.run_metadata["str"] == "1.0" assert "list_str" in artifact_1.run_metadata - assert ( - len(set(artifact_1.run_metadata["list_str"]) - {"1.0", "2.0"}) - == 0 - ) + assert len(set(artifact_1.run_metadata["list_str"]) - {"1.0", "2.0"}) == 0 assert "list_floats" in artifact_1.run_metadata for each in artifact_1.run_metadata["list_floats"]: if 0.99 < each < 1.01: diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 64269a7f549..2d6f707e630 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -5449,12 +5449,20 @@ def test_metadata_full_cycle_with_cascade_deletion( else None, ) ) - rm = client.zen_store.get_run_metadata(rm[0].id, True) - assert rm.key == "foo" - assert rm.value == "bar" - assert rm.resource_id == resource.id - assert rm.resource_type == type_ - assert rm.type == MetadataTypeEnum.STRING + if type_ == MetadataResourceTypes.PIPELINE_RUN: + rm = client.zen_store.get_run(resource.id, True).metadata + assert rm.key == "foo" + assert rm.value == "bar" + assert rm.resource_id == resource.id + assert rm.resource_type == type_ + assert rm.type == MetadataTypeEnum.STRING + elif type_ == MetadataResourceTypes.STEP_RUN: + rm = client.zen_store.get_run_step(resource.id, True).metadata + assert rm.key == "foo" + assert rm.value == "bar" + assert rm.resource_id == resource.id + assert rm.resource_type == type_ + assert rm.type == MetadataTypeEnum.STRING if type_ == MetadataResourceTypes.ARTIFACT_VERSION: client.zen_store.delete_artifact_version(resource.id) @@ -5468,9 +5476,6 @@ def test_metadata_full_cycle_with_cascade_deletion( client.zen_store.delete_run(pr.id) client.zen_store.delete_deployment(deployment.id) - with pytest.raises(KeyError): - client.zen_store.get_run_metadata(rm.id) - client.zen_store.delete_stack_component(sc.id) From 5aef8ab386c2e544f2fe4c5330e7dacba75f1763 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 10:38:00 +0200 Subject: [PATCH 011/124] Fixed step run schemas --- src/zenml/zen_stores/schemas/step_run_schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 07812b26ec4..7917fd01d9d 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -214,7 +214,7 @@ def to_model( or a step_configuration. """ run_metadata = { - metadata_schema.key: metadata_schema.to_model() + metadata_schema.key: metadata_schema.value for metadata_schema in self.run_metadata } From 0b66f072abd9c5700e6f755a254b91dce1c8945c Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 21 Oct 2024 08:47:06 +0000 Subject: [PATCH 012/124] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- examples/e2e/steps/deployment/deployment_deploy.py | 2 +- examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py | 2 +- examples/e2e/steps/promotion/promote_with_metric_compare.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 74cc33d8594..b008b2c1e99 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.10 +_commit: 2024.10.21 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e/steps/deployment/deployment_deploy.py b/examples/e2e/steps/deployment/deployment_deploy.py index 3fb0d879f3f..dad351e45be 100644 --- a/examples/e2e/steps/deployment/deployment_deploy.py +++ b/examples/e2e/steps/deployment/deployment_deploy.py @@ -67,7 +67,7 @@ def deployment_deploy() -> ( registry_model_name=model.name, registry_model_version=model.run_metadata[ "model_registry_version" - ].value, + ], replace_existing=True, ) else: diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py index 7d5a6bc33ea..65e524ecd98 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py @@ -50,7 +50,7 @@ def hp_tuning_select_best_model( hp_output = model.get_data_artifact("hp_result") model_: ClassifierMixin = hp_output.load() # fetch metadata we attached earlier - metric = float(hp_output.run_metadata["metric"].value) + metric = float(hp_output.run_metadata["metric"]) if best_model is None or best_metric < metric: best_model = model_ ### YOUR CODE ENDS HERE ### diff --git a/examples/e2e/steps/promotion/promote_with_metric_compare.py b/examples/e2e/steps/promotion/promote_with_metric_compare.py index 038d219d32d..6bc580f47ba 100644 --- a/examples/e2e/steps/promotion/promote_with_metric_compare.py +++ b/examples/e2e/steps/promotion/promote_with_metric_compare.py @@ -92,14 +92,14 @@ def promote_with_metric_compare( # Promote in Model Registry latest_version_model_registry_number = latest_version.run_metadata[ "model_registry_version" - ].value + ] if current_version_number is None: current_version_model_registry_number = ( latest_version_model_registry_number ) else: current_version_model_registry_number = ( - current_version.run_metadata["model_registry_version"].value + current_version.run_metadata["model_registry_version"] ) promote_in_model_registry( latest_version=latest_version_model_registry_number, @@ -111,7 +111,7 @@ def promote_with_metric_compare( else: promoted_version = current_version.run_metadata[ "model_registry_version" - ].value + ] logger.info( f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry" From 4b2434aa0e72057803234e07de6325d6dca0bab9 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 21 Oct 2024 08:50:05 +0000 Subject: [PATCH 013/124] Auto-update of NLP template --- examples/e2e_nlp/.copier-answers.yml | 2 +- examples/e2e_nlp/gradio/requirements.txt | 2 +- .../e2e_nlp/steps/promotion/promote_get_metrics.py | 12 ++++-------- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index 3ca2ba198fe..e509aae2760 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.09.23 +_commit: 2024.10.21 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/e2e_nlp/gradio/requirements.txt b/examples/e2e_nlp/gradio/requirements.txt index 1bddfdfb85b..b53f1df9e62 100644 --- a/examples/e2e_nlp/gradio/requirements.txt +++ b/examples/e2e_nlp/gradio/requirements.txt @@ -9,4 +9,4 @@ pandas==1.5.3 session_info==1.0.0 scikit-learn==1.5.0 transformers==4.28.1 -IPython==7.34.0 \ No newline at end of file +IPython==8.10.0 \ No newline at end of file diff --git a/examples/e2e_nlp/steps/promotion/promote_get_metrics.py b/examples/e2e_nlp/steps/promotion/promote_get_metrics.py index 7f2951a5865..b24ac42245c 100644 --- a/examples/e2e_nlp/steps/promotion/promote_get_metrics.py +++ b/examples/e2e_nlp/steps/promotion/promote_get_metrics.py @@ -56,9 +56,7 @@ def promote_get_metrics() -> ( # Get current model version metric in current run model = get_step_context().model - current_metrics = ( - model.get_model_artifact("model").run_metadata["metrics"].value - ) + current_metrics = model.get_model_artifact("model").run_metadata["metrics"] logger.info(f"Current model version metrics are {current_metrics}") # Get latest saved model version metric in target environment @@ -72,11 +70,9 @@ def promote_get_metrics() -> ( except KeyError: latest_version = None if latest_version: - latest_metrics = ( - latest_version.get_model_artifact("model") - .run_metadata["metrics"] - .value - ) + latest_metrics = latest_version.get_model_artifact( + "model" + ).run_metadata["metrics"] logger.info(f"Latest model version metrics are {latest_metrics}") else: logger.info("No currently promoted model version found.") From 8f4af6e9ca948ac92a7e400f797d994928b1f05e Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 11:41:51 +0200 Subject: [PATCH 014/124] Fixed tests, removed additional .value access --- .../component-guide/orchestrators/kubeflow.md | 2 +- .../orchestrators/sagemaker_orchestrator.py | 2 +- .../orchestrators/azureml_orchestrator.py | 2 +- src/zenml/lineage_graph/lineage_graph.py | 8 +- tests/integration/functional/test_client.py | 169 ++---------------- .../functional/zen_stores/test_zen_store.py | 2 +- 6 files changed, 24 insertions(+), 161 deletions(-) diff --git a/docs/book/component-guide/orchestrators/kubeflow.md b/docs/book/component-guide/orchestrators/kubeflow.md index 65adf45a4c5..f3830cf03c4 100644 --- a/docs/book/component-guide/orchestrators/kubeflow.md +++ b/docs/book/component-guide/orchestrators/kubeflow.md @@ -198,7 +198,7 @@ Kubeflow comes with its own UI that you can use to find further details about yo from zenml.client import Client pipeline_run = Client().get_pipeline_run("") -orchestrator_url = pipeline_run.run_metadata["orchestrator_url"].value +orchestrator_url = pipeline_run.run_metadata["orchestrator_url"] ``` #### Additional configuration diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index 342092416c7..3d001e2e208 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -560,7 +560,7 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus: # Fetch the status of the _PipelineExecution if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata: - run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value + run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID] elif run.orchestrator_run_id is not None: run_id = run.orchestrator_run_id else: diff --git a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py index 1e0f68143ff..d0e2058ca1d 100644 --- a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py +++ b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py @@ -482,7 +482,7 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus: # Fetch the status of the PipelineJob if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata: - run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value + run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID] elif run.orchestrator_run_id is not None: run_id = run.orchestrator_run_id else: diff --git a/src/zenml/lineage_graph/lineage_graph.py b/src/zenml/lineage_graph/lineage_graph.py index 9dba12d304b..b175ef71283 100644 --- a/src/zenml/lineage_graph/lineage_graph.py +++ b/src/zenml/lineage_graph/lineage_graph.py @@ -188,8 +188,8 @@ def add_step_node( inputs={k: v.uri for k, v in step.inputs.items()}, outputs={k: v.uri for k, v in step.outputs.items()}, metadata=[ - (m.key, str(m.value), str(m.type)) - for m in step.run_metadata.values() + (k, v, str(type(v))) + for k,v in step.run_metadata.items() ], ), ) @@ -225,8 +225,8 @@ def add_artifact_node( producer_step_id=str(artifact.producer_step_run_id), uri=artifact.uri, metadata=[ - (m.key, str(m.value), str(m.type)) - for m in artifact.run_metadata.values() + (k, v, str(type(v))) + for k,v in artifact.run_metadata.items() ], ), ) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 23983031e40..ef8b61e0f89 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -482,190 +482,53 @@ def test_listing_pipelines(clean_client): def test_create_run_metadata_for_pipeline_run(clean_client_with_run: Client): """Test creating run metadata linked only to a pipeline run.""" pipeline_run = clean_client_with_run.list_pipeline_runs()[0] - existing_metadata = clean_client_with_run.list_run_metadata( - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - # Assert that the created metadata is correct - new_metadata = clean_client_with_run.create_run_metadata( + clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, resource_id=pipeline_run.id, resource_type=MetadataResourceTypes.PIPELINE_RUN, ) - assert isinstance(new_metadata, list) - assert len(new_metadata) == 1 - assert new_metadata[0].key == "axel" - assert new_metadata[0].value == "is awesome" - assert new_metadata[0].type == MetadataTypeEnum.STRING - assert new_metadata[0].resource_id == pipeline_run.id - assert new_metadata[0].resource_type == MetadataResourceTypes.PIPELINE_RUN - assert new_metadata[0].stack_component_id is None - - # Assert new metadata is linked to the pipeline run - all_metadata = clean_client_with_run.list_run_metadata( - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - assert len(all_metadata) == len(existing_metadata) + 1 - - -def test_create_run_metadata_for_pipeline_run_and_component( - clean_client_with_run: Client, -): - """Test creating metadata linked to a pipeline run and a stack component""" - pipeline_run = clean_client_with_run.list_pipeline_runs()[0] - orchestrator_id = clean_client_with_run.active_stack_model.components[ - "orchestrator" - ][0].id - existing_metadata = clean_client_with_run.list_run_metadata( - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - existing_component_metadata = clean_client_with_run.list_run_metadata( - stack_component_id=orchestrator_id - ) - - # Assert that the created metadata is correct - new_metadata = clean_client_with_run.create_run_metadata( - metadata={"aria": "is awesome too"}, - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - stack_component_id=orchestrator_id, - ) - assert isinstance(new_metadata, list) - assert len(new_metadata) == 1 - assert new_metadata[0].key == "aria" - assert new_metadata[0].value == "is awesome too" - assert new_metadata[0].type == MetadataTypeEnum.STRING - assert new_metadata[0].resource_id == pipeline_run.id - assert new_metadata[0].resource_type == MetadataResourceTypes.PIPELINE_RUN - assert new_metadata[0].stack_component_id == orchestrator_id - - # Assert new metadata is linked to the pipeline run - registered_metadata = clean_client_with_run.list_run_metadata( - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - assert len(registered_metadata) == len(existing_metadata) + 1 + rm = clean_client_with_run.get_pipeline_run(pipeline_run.id).run_metadata - # Assert new metadata is linked to the stack component - registered_component_metadata = clean_client_with_run.list_run_metadata( - stack_component_id=orchestrator_id - ) - assert ( - len(registered_component_metadata) - == len(existing_component_metadata) + 1 - ) + assert isinstance(rm, dict) + assert len(rm.values()) == 1 + assert rm["axel"] == "is_awesome" def test_create_run_metadata_for_step_run(clean_client_with_run: Client): """Test creating run metadata linked only to a step run.""" step_run = clean_client_with_run.list_run_steps()[0] - existing_metadata = clean_client_with_run.list_run_metadata( - resource_id=step_run.id, resource_type=MetadataResourceTypes.STEP_RUN - ) # Assert that the created metadata is correct - new_metadata = clean_client_with_run.create_run_metadata( + clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, resource_id=step_run.id, resource_type=MetadataResourceTypes.STEP_RUN, ) - assert isinstance(new_metadata, list) - assert len(new_metadata) == 1 - assert new_metadata[0].key == "axel" - assert new_metadata[0].value == "is awesome" - assert new_metadata[0].type == MetadataTypeEnum.STRING - assert new_metadata[0].resource_id == step_run.id - assert new_metadata[0].resource_type == MetadataResourceTypes.STEP_RUN - assert new_metadata[0].stack_component_id is None - - # Assert new metadata is linked to the step run - registered_metadata = clean_client_with_run.list_run_metadata( - resource_id=step_run.id, resource_type=MetadataResourceTypes.STEP_RUN - ) - assert len(registered_metadata) == len(existing_metadata) + 1 - - -def test_create_run_metadata_for_step_run_and_component( - clean_client_with_run: Client, -): - """Test creating metadata linked to a step run and a stack component""" - step_run = clean_client_with_run.list_run_steps()[0] - orchestrator_id = clean_client_with_run.active_stack_model.components[ - "orchestrator" - ][0].id - existing_metadata = clean_client_with_run.list_run_metadata( - resource_id=step_run.id, resource_type=MetadataResourceTypes.STEP_RUN - ) - existing_component_metadata = clean_client_with_run.list_run_metadata( - stack_component_id=orchestrator_id - ) + rm = clean_client_with_run.get_run_step(step_run.id).run_metadata - # Assert that the created metadata is correct - new_metadata = clean_client_with_run.create_run_metadata( - metadata={"aria": "is awesome too"}, - resource_id=step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, - stack_component_id=orchestrator_id, - ) - assert isinstance(new_metadata, list) - assert len(new_metadata) == 1 - assert new_metadata[0].key == "aria" - assert new_metadata[0].value == "is awesome too" - assert new_metadata[0].type == MetadataTypeEnum.STRING - assert new_metadata[0].resource_id == step_run.id - assert new_metadata[0].resource_type == MetadataResourceTypes.STEP_RUN - assert new_metadata[0].stack_component_id == orchestrator_id - - # Assert new metadata is linked to the step run - registered_metadata = clean_client_with_run.list_run_metadata( - resource_id=step_run.id, resource_type=MetadataResourceTypes.STEP_RUN - ) - assert len(registered_metadata) == len(existing_metadata) + 1 + assert isinstance(rm, dict) + assert len(rm.values()) == 1 + assert rm["axel"] == "is_awesome" - # Assert new metadata is linked to the stack component - registered_component_metadata = clean_client_with_run.list_run_metadata( - stack_component_id=orchestrator_id - ) - assert ( - len(registered_component_metadata) - == len(existing_component_metadata) + 1 - ) def test_create_run_metadata_for_artifact(clean_client_with_run: Client): """Test creating run metadata linked to an artifact.""" artifact_version = clean_client_with_run.list_artifact_versions()[0] - existing_metadata = clean_client_with_run.list_run_metadata( - resource_id=artifact_version.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, - ) # Assert that the created metadata is correct - new_metadata = clean_client_with_run.create_run_metadata( + clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, resource_id=artifact_version.id, resource_type=MetadataResourceTypes.ARTIFACT_VERSION, ) - assert isinstance(new_metadata, list) - assert len(new_metadata) == 1 - assert new_metadata[0].key == "axel" - assert new_metadata[0].value == "is awesome" - assert new_metadata[0].type == MetadataTypeEnum.STRING - assert new_metadata[0].resource_id == artifact_version.id - assert ( - new_metadata[0].resource_type == MetadataResourceTypes.ARTIFACT_VERSION - ) - assert new_metadata[0].stack_component_id is None - # Assert new metadata is linked to the artifact - registered_metadata = clean_client_with_run.list_run_metadata( - resource_id=artifact_version.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, - ) - assert len(registered_metadata) == len(existing_metadata) + 1 + rm = clean_client_with_run.get_artifact_version(artifact_version.id).run_metadata + + assert isinstance(rm, dict) + assert len(rm.values()) == 1 + assert rm["axel"] == "is_awesome" # .---------. diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 2d6f707e630..a6c0a40bbf6 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -5435,7 +5435,7 @@ def test_metadata_full_cycle_with_cascade_deletion( pr if type_ == MetadataResourceTypes.PIPELINE_RUN else sr ) - rm = client.zen_store.create_run_metadata( + client.zen_store.create_run_metadata( RunMetadataRequest( user=client.active_user.id, workspace=client.active_workspace.id, From edba62579c7c5e3e8fed7d77574f993db80cb773 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 14:29:50 +0200 Subject: [PATCH 015/124] Further fixing --- src/zenml/lineage_graph/lineage_graph.py | 4 ++-- .../zen_stores/schemas/step_run_schemas.py | 2 +- tests/integration/functional/test_client.py | 13 ++++++------- .../functional/zen_stores/test_zen_store.py | 17 +++++------------ 4 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/zenml/lineage_graph/lineage_graph.py b/src/zenml/lineage_graph/lineage_graph.py index b175ef71283..dab253bad51 100644 --- a/src/zenml/lineage_graph/lineage_graph.py +++ b/src/zenml/lineage_graph/lineage_graph.py @@ -189,7 +189,7 @@ def add_step_node( outputs={k: v.uri for k, v in step.outputs.items()}, metadata=[ (k, v, str(type(v))) - for k,v in step.run_metadata.items() + for k, v in step.run_metadata.items() ], ), ) @@ -226,7 +226,7 @@ def add_artifact_node( uri=artifact.uri, metadata=[ (k, v, str(type(v))) - for k,v in artifact.run_metadata.items() + for k, v in artifact.run_metadata.items() ], ), ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 7917fd01d9d..63ef4f7e9cf 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -214,7 +214,7 @@ def to_model( or a step_configuration. """ run_metadata = { - metadata_schema.key: metadata_schema.value + metadata_schema.key: json.loads(metadata_schema.value) for metadata_schema in self.run_metadata } diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index ef8b61e0f89..53d6c891681 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -55,7 +55,6 @@ StackExistsError, ) from zenml.io import fileio -from zenml.metadata.metadata_types import MetadataTypeEnum from zenml.model.model import Model from zenml.models import ( ComponentResponse, @@ -492,7 +491,7 @@ def test_create_run_metadata_for_pipeline_run(clean_client_with_run: Client): assert isinstance(rm, dict) assert len(rm.values()) == 1 - assert rm["axel"] == "is_awesome" + assert rm["axel"] == "is awesome" def test_create_run_metadata_for_step_run(clean_client_with_run: Client): @@ -509,8 +508,7 @@ def test_create_run_metadata_for_step_run(clean_client_with_run: Client): assert isinstance(rm, dict) assert len(rm.values()) == 1 - assert rm["axel"] == "is_awesome" - + assert rm["axel"] == "is awesome" def test_create_run_metadata_for_artifact(clean_client_with_run: Client): @@ -524,11 +522,12 @@ def test_create_run_metadata_for_artifact(clean_client_with_run: Client): resource_type=MetadataResourceTypes.ARTIFACT_VERSION, ) - rm = clean_client_with_run.get_artifact_version(artifact_version.id).run_metadata + rm = clean_client_with_run.get_artifact_version( + artifact_version.id + ).run_metadata assert isinstance(rm, dict) - assert len(rm.values()) == 1 - assert rm["axel"] == "is_awesome" + assert rm["axel"] == "is awesome" # .---------. diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index a6c0a40bbf6..b0153f263cf 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -5450,19 +5450,12 @@ def test_metadata_full_cycle_with_cascade_deletion( ) ) if type_ == MetadataResourceTypes.PIPELINE_RUN: - rm = client.zen_store.get_run(resource.id, True).metadata - assert rm.key == "foo" - assert rm.value == "bar" - assert rm.resource_id == resource.id - assert rm.resource_type == type_ - assert rm.type == MetadataTypeEnum.STRING + rm = client.zen_store.get_run(resource.id, True).run_metadata + assert rm["foo"] == "bar" + elif type_ == MetadataResourceTypes.STEP_RUN: - rm = client.zen_store.get_run_step(resource.id, True).metadata - assert rm.key == "foo" - assert rm.value == "bar" - assert rm.resource_id == resource.id - assert rm.resource_type == type_ - assert rm.type == MetadataTypeEnum.STRING + rm = client.zen_store.get_run_step(resource.id, True).run_metadata + assert rm["foo"] == "bar" if type_ == MetadataResourceTypes.ARTIFACT_VERSION: client.zen_store.delete_artifact_version(resource.id) From c2b6955ccf365810dca702753b114f508202657c Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Mon, 21 Oct 2024 16:10:45 +0200 Subject: [PATCH 016/124] Fixed linting issues --- .../integrations/gcp/orchestrators/vertex_orchestrator.py | 2 +- src/zenml/models/v2/core/run_metadata.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py index 2c02bb71b8a..bb218febb84 100644 --- a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +++ b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py @@ -835,7 +835,7 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus: # Fetch the status of the PipelineJob if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata: - run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value + run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID] elif run.orchestrator_run_id is not None: run_id = run.orchestrator_run_id else: diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 366f12edbb1..bed5b56a153 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -16,7 +16,7 @@ from typing import Dict, Optional from uuid import UUID -from pydantic import Field +from pydantic import Field, BaseModel from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum @@ -48,7 +48,7 @@ class RunMetadataRequest(WorkspaceScopedRequest): ) -class LazyRunMetadataResponse(WorkspaceScopedResponse): +class LazyRunMetadataResponse(BaseModel): """Lazy run metadata response. Used if the run metadata is accessed from the model in @@ -62,7 +62,7 @@ class LazyRunMetadataResponse(WorkspaceScopedResponse): lazy_load_model_name: str lazy_load_model_version: Optional[str] = None - def get_body(self) -> None: # type: ignore[override] + def get_body(self) -> None: """Protects from misuse of the lazy loader. Raises: @@ -72,7 +72,7 @@ def get_body(self) -> None: # type: ignore[override] "Cannot access run metadata body before pipeline runs." ) - def get_metadata(self) -> None: # type: ignore[override] + def get_metadata(self) -> None: """Protects from misuse of the lazy loader. Raises: From 8f6d305a4de6fb67b571fba6f66a2a5fa268a2bb Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Tue, 22 Oct 2024 09:20:04 +0200 Subject: [PATCH 017/124] Reformatted --- src/zenml/models/v2/core/run_metadata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index bed5b56a153..333849633cd 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -16,13 +16,12 @@ from typing import Dict, Optional from uuid import UUID -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, - WorkspaceScopedResponse, ) # ------------------ Request Model ------------------ From 6b183220111e3972975b7d25eaf1e432799de358 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Tue, 22 Oct 2024 12:40:58 +0200 Subject: [PATCH 018/124] Linted, formatted and tested again --- src/zenml/metadata/lazy_load.py | 7 +++---- src/zenml/models/__init__.py | 1 - src/zenml/models/v2/core/run_metadata.py | 25 ++---------------------- 3 files changed, 5 insertions(+), 28 deletions(-) diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py index 4064450142a..8c7d37b487f 100644 --- a/src/zenml/metadata/lazy_load.py +++ b/src/zenml/metadata/lazy_load.py @@ -13,10 +13,9 @@ # permissions and limitations under the License. """Run Metadata Lazy Loader definition.""" -from typing import TYPE_CHECKING, Optional +from typing import Optional -if TYPE_CHECKING: - from zenml.models import RunMetadataResponse +from zenml.metadata.metadata_types import MetadataType class RunMetadataLazyGetter: @@ -47,7 +46,7 @@ def __init__( self._lazy_load_artifact_name = _lazy_load_artifact_name self._lazy_load_artifact_version = _lazy_load_artifact_version - def __getitem__(self, key: str) -> "RunMetadataResponse": + def __getitem__(self, key: str) -> MetadataType: """Get the metadata for the given key. Args: diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index a330565c7bf..887221cd947 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -412,7 +412,6 @@ FlavorResponseBody.model_rebuild() FlavorResponseMetadata.model_rebuild() LazyArtifactVersionResponse.model_rebuild() -LazyRunMetadataResponse.model_rebuild() ModelResponseBody.model_rebuild() ModelResponseMetadata.model_rebuild() ModelVersionResponseBody.model_rebuild() diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 333849633cd..87c0da9d2f4 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -16,7 +16,7 @@ from typing import Dict, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import Field from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum @@ -47,36 +47,15 @@ class RunMetadataRequest(WorkspaceScopedRequest): ) -class LazyRunMetadataResponse(BaseModel): +class LazyRunMetadataResponse(dict): """Lazy run metadata response. Used if the run metadata is accessed from the model in a pipeline context available only during pipeline compilation. """ - id: Optional[UUID] = None # type: ignore[assignment] lazy_load_artifact_name: Optional[str] = None lazy_load_artifact_version: Optional[str] = None lazy_load_metadata_name: Optional[str] = None lazy_load_model_name: str lazy_load_model_version: Optional[str] = None - - def get_body(self) -> None: - """Protects from misuse of the lazy loader. - - Raises: - RuntimeError: always - """ - raise RuntimeError( - "Cannot access run metadata body before pipeline runs." - ) - - def get_metadata(self) -> None: - """Protects from misuse of the lazy loader. - - Raises: - RuntimeError: always - """ - raise RuntimeError( - "Cannot access run metadata metadata before pipeline runs." - ) From 8b3a1bd7264a1cb9905397d6c8543cd5a59978ea Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Tue, 22 Oct 2024 13:06:35 +0200 Subject: [PATCH 019/124] Typing --- src/zenml/models/v2/core/run_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 87c0da9d2f4..f58225710c5 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -47,7 +47,7 @@ class RunMetadataRequest(WorkspaceScopedRequest): ) -class LazyRunMetadataResponse(dict): +class LazyRunMetadataResponse(dict[str, MetadataType]): """Lazy run metadata response. Used if the run metadata is accessed from the model in From 5cc7b4474869d98f0a889bdc019874d647fc4d3a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 28 Oct 2024 17:15:55 +0100 Subject: [PATCH 020/124] Maybe fix everything --- src/zenml/metadata/lazy_load.py | 20 +++++++++++++++++--- src/zenml/models/__init__.py | 1 - src/zenml/models/v2/core/run_metadata.py | 14 -------------- src/zenml/steps/base_step.py | 2 +- src/zenml/steps/entrypoint_function_utils.py | 2 +- 5 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py index 8c7d37b487f..7ce2cc0d30a 100644 --- a/src/zenml/metadata/lazy_load.py +++ b/src/zenml/metadata/lazy_load.py @@ -15,9 +15,25 @@ from typing import Optional +from pydantic import BaseModel + from zenml.metadata.metadata_types import MetadataType +class LazyRunMetadataResponse(BaseModel): + """Lazy run metadata response. + + Used if the run metadata is accessed from the model in + a pipeline context available only during pipeline compilation. + """ + + lazy_load_artifact_name: Optional[str] = None + lazy_load_artifact_version: Optional[str] = None + lazy_load_metadata_name: Optional[str] = None + lazy_load_model_name: str + lazy_load_model_version: Optional[str] = None + + class RunMetadataLazyGetter: """Run Metadata Lazy Getter helper class. @@ -55,9 +71,7 @@ def __getitem__(self, key: str) -> MetadataType: Returns: The metadata lazy loader wrapper for the given key. """ - from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse - - return LazyRunMetadataResponse( + return LazyRunMetadataResponse( # type: ignore[return-value] lazy_load_model_name=self._lazy_load_model_name, lazy_load_model_version=self._lazy_load_model_version, lazy_load_artifact_name=self._lazy_load_artifact_name, diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 887221cd947..042ed8a1185 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -238,7 +238,6 @@ ) from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse from zenml.models.v2.core.run_metadata import ( - LazyRunMetadataResponse, RunMetadataRequest, ) from zenml.models.v2.core.schedule import ( diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index f58225710c5..c4a2ef8e678 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -45,17 +45,3 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) - - -class LazyRunMetadataResponse(dict[str, MetadataType]): - """Lazy run metadata response. - - Used if the run metadata is accessed from the model in - a pipeline context available only during pipeline compilation. - """ - - lazy_load_artifact_name: Optional[str] = None - lazy_load_artifact_version: Optional[str] = None - lazy_load_metadata_name: Optional[str] = None - lazy_load_model_name: str - lazy_load_model_version: Optional[str] = None diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 2a3d324608c..ddb8cacee61 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -326,11 +326,11 @@ def _parse_call_args( The artifacts, external artifacts, model version artifacts/metadata and parameters for the step. """ from zenml.artifacts.external_artifact import ExternalArtifact + from zenml.metadata.lazy_load import LazyRunMetadataResponse from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.models.v2.core.artifact_version import ( LazyArtifactVersionResponse, ) - from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse signature = inspect.signature(self.entrypoint, follow_wrapped=True) diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 997d639d8ee..9f87ea826b7 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -32,7 +32,7 @@ from zenml.exceptions import StepInterfaceError from zenml.logger import get_logger from zenml.materializers.base_materializer import BaseMaterializer -from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse +from zenml.metadata.lazy_load import LazyRunMetadataResponse from zenml.steps.utils import ( OutputSignature, parse_return_type_annotations, From c368dec825c45b01aa23c68410b13e0b834af7fa Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 28 Oct 2024 17:23:24 +0100 Subject: [PATCH 021/124] Apply some feedback --- src/zenml/cli/model.py | 5 +-- src/zenml/models/v2/core/model_version.py | 2 +- .../schemas/run_metadata_schemas.py | 40 ------------------- 3 files changed, 2 insertions(+), 45 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index d8ea4534bfd..403d4a7b42a 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -59,9 +59,6 @@ def _model_to_print(model: ModelResponse) -> Dict[str, Any]: def _model_version_to_print( model_version: ModelVersionResponse, ) -> Dict[str, Any]: - run_metadata = None - if model_version.run_metadata: - run_metadata = {k: v for k, v in model_version.run_metadata.items()} return { "id": model_version.id, "model": model_version.model.name, @@ -69,7 +66,7 @@ def _model_version_to_print( "number": model_version.number, "description": model_version.description, "stage": model_version.stage, - "run_metadata": run_metadata, + "run_metadata": model_version.run_metadata, "tags": [t.name for t in model_version.tags], "data_artifacts_count": len(model_version.data_artifact_ids), "model_artifacts_count": len(model_version.model_artifact_ids), diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 88028900518..935f4d3e30a 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -302,7 +302,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def run_metadata(self) -> Optional[Dict[str, MetadataType]]: + def run_metadata(self) -> Dict[str, MetadataType]: """The `run_metadata` property. Returns: diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 4de528abd5a..18d203111c7 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -103,43 +103,3 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str - - # def to_model( - # self, - # include_metadata: bool = False, - # include_resources: bool = False, - # **kwargs: Any, - # ) -> "RunMetadataResponse": - # """Convert a `RunMetadataSchema` to a `RunMetadataResponse`. - # - # Args: - # include_metadata: Whether the metadata will be filled. - # include_resources: Whether the resources will be filled. - # **kwargs: Keyword arguments to allow schema specific logic - # - # - # Returns: - # The created `RunMetadataResponse`. - # """ - # body = RunMetadataResponseBody( - # user=self.user.to_model() if self.user else None, - # key=self.key, - # created=self.created, - # updated=self.updated, - # value=json.loads(self.value), - # type=MetadataTypeEnum(self.type), - # ) - # metadata = None - # if include_metadata: - # metadata = RunMetadataResponseMetadata( - # workspace=self.workspace.to_model(), - # resource_id=self.resource_id, - # resource_type=MetadataResourceTypes(self.resource_type), - # stack_component_id=self.stack_component_id, - # ) - # - # return RunMetadataResponse( - # id=self.id, - # body=body, - # metadata=metadata, - # ) From 74c1a425e97a4a6063f2f52a4130eae4020c6b96 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 5 Nov 2024 23:27:36 +0100 Subject: [PATCH 022/124] new operation --- src/zenml/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 1083c8390fe..96be55287e9 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -245,6 +245,7 @@ class GenericFilterOps(StrEnum): CONTAINS = "contains" STARTSWITH = "startswith" ENDSWITH = "endswith" + ONEOF = "oneof" GTE = "gte" GT = "gt" LTE = "lte" From 53dc8e8a17507155ec98ac2b94a1415a410a01ad Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 5 Nov 2024 23:29:46 +0100 Subject: [PATCH 023/124] new log_metadata function --- src/zenml/__init__.py | 1 + src/zenml/pipelines/utils.py | 66 ++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 src/zenml/pipelines/utils.py diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 299a090d189..9821f414104 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -48,6 +48,7 @@ from zenml.pipelines import get_pipeline_context, pipeline from zenml.steps import step, get_step_context from zenml.steps.utils import log_step_metadata +from zenml.pipelines.utils import log_metadata from zenml.entrypoints import entrypoint __all__ = [ diff --git a/src/zenml/pipelines/utils.py b/src/zenml/pipelines/utils.py new file mode 100644 index 00000000000..b56224e80ac --- /dev/null +++ b/src/zenml/pipelines/utils.py @@ -0,0 +1,66 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility functions to run ZenML pipelines.""" + +import contextlib +from typing import Dict, Optional, Union +from uuid import UUID + +from zenml.client import Client +from zenml.enums import MetadataResourceTypes +from zenml.metadata.metadata_types import MetadataType +from zenml.steps.step_context import get_step_context + + +def log_metadata( + metadata: Dict[str, MetadataType], + run_name_id_or_prefix: Optional[Union[str, UUID]] = None, +) -> None: + """Logs metadata. + + Args: + metadata: The metadata to log. + run_name_id_or_prefix: The name, ID or prefix of the run to log metadata + for. Can be omitted when being called inside a step. + + Raises: + ValueError: If no run identifier is provided and the function is not + called from within a step. + """ + step_context = None + if not run_name_id_or_prefix: + with contextlib.suppress(RuntimeError): + step_context = get_step_context() + run_name_id_or_prefix = step_context.pipeline_run.id + + if not run_name_id_or_prefix: + raise ValueError( + "No pipeline name or ID provided and you are not running " + "within a step. Please provide a pipeline name or ID, or " + "provide a run ID." + ) + + client = Client() + if step_context is None and not isinstance(run_name_id_or_prefix, UUID): + run_name_id_or_prefix = client.get_pipeline_run( + name_id_or_prefix=run_name_id_or_prefix, + ).id + + # TODO: Should we also create the corresponding step and model metadata + # as well? + client.create_run_metadata( + metadata=metadata, + resource_id=run_name_id_or_prefix, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) From 68a455c639ab748414a957cab81c152882467804 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 5 Nov 2024 23:36:51 +0100 Subject: [PATCH 024/124] changes to the base filters --- src/zenml/models/v2/base/filter.py | 84 +++++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 7 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 486226a3a5f..25e010b3db4 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Base filter model definitions.""" +import json from abc import ABC, abstractmethod from datetime import datetime from typing import ( @@ -36,7 +37,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, cast, Float from sqlmodel import SQLModel from zenml.constants import ( @@ -171,6 +172,11 @@ class StrFilter(Filter): GenericFilterOps.STARTSWITH, GenericFilterOps.CONTAINS, GenericFilterOps.ENDSWITH, + GenericFilterOps.ONEOF, + GenericFilterOps.GT, + GenericFilterOps.GTE, + GenericFilterOps.LT, + GenericFilterOps.LTE, ] def generate_query_conditions_from_column(self, column: Any) -> Any: @@ -190,6 +196,32 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: return column.endswith(f"{self.value}") if self.operation == GenericFilterOps.NOT_EQUALS: return column != self.value + if self.operation == GenericFilterOps.ONEOF: + return column.in_(self.value) + if self.operation in { + GenericFilterOps.GT, + GenericFilterOps.LT, + GenericFilterOps.GTE, + GenericFilterOps.LTE + }: + # Try to cast the column to a numeric type for numeric operations + try: + numeric_column = cast(column, Float) + if self.operation == GenericFilterOps.GT: + return numeric_column > self.value + if self.operation == GenericFilterOps.LT: + return numeric_column < self.value + if self.operation == GenericFilterOps.GTE: + return numeric_column >= self.value + if self.operation == GenericFilterOps.LTE: + return numeric_column <= self.value + except Exception: + # Handle the exception or fallback as needed + raise ValueError( + "Failed to cast column to numeric type for comparison" + ) + else: + raise ValueError("Invalid operation or incompatible data type") return column == self.value @@ -598,6 +630,13 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: ): value = split_value[1] operator = GenericFilterOps(split_value[0]) + + if operator == operator.ONEOF: + try: + value = json.loads(value) + except: + raise ValueError("Add some error message here....") + return value, operator def generate_name_or_id_query_conditions( @@ -833,16 +872,17 @@ def define_filter( # Create str filters if self.is_str_field(column): - return StrFilter( - operation=GenericFilterOps(operator), + return self._define_str_filter( + operator=GenericFilterOps(operator), column=column, value=value, ) # Handle unsupported datatypes logger.warning( - f"The Datatype {self._model_class.model_fields[column].annotation} might " - "not be supported for filtering. Defaulting to a string filter." + f"The Datatype {self._model_class.model_fields[column].annotation} " + "might not be supported for filtering. Defaulting to a string " + "filter." ) return StrFilter( operation=GenericFilterOps(operator), @@ -1032,8 +1072,9 @@ def _define_uuid_filter( "Invalid value passed as UUID query parameter." ) from e - # Cast the value to string for further comparisons. - value = str(value) + # For equality checks, ensure that the value is a valid UUID. + if operator == GenericFilterOps.ONEOF and not isinstance(value, list): + raise ValueError("") # Generate the filter. uuid_filter = UUIDFilter( @@ -1043,6 +1084,35 @@ def _define_uuid_filter( ) return uuid_filter + @staticmethod + def _define_str_filter( + column: str, value: Any, operator: GenericFilterOps + ) -> StrFilter: + """Define a str filter for a given column. + + Args: + column: The column to filter on. + value: The UUID value by which to filter. + operator: The operator to use for filtering. + + Returns: + A Filter object. + + Raises: + ValueError: If the value is not a proper value. + """ + # For equality checks, ensure that the value is a valid UUID. + if operator == GenericFilterOps.ONEOF and not isinstance(value, list): + raise ValueError("") + + # Generate the filter. + str_filter = StrFilter( + operation=GenericFilterOps(operator), + column=column, + value=value, + ) + return str_filter + @staticmethod def _define_bool_filter( column: str, value: Any, operator: GenericFilterOps From 4af4165eed571a07712472b0e168a9d96af6c235 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 11:10:17 +0100 Subject: [PATCH 025/124] new filters --- src/zenml/models/v2/core/artifact_version.py | 22 ++++++++++++++++++++ src/zenml/models/v2/core/model_version.py | 22 ++++++++++++++++++++ src/zenml/models/v2/core/pipeline_run.py | 22 +++++++++++++++++++- src/zenml/models/v2/core/step_run.py | 21 +++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 6ee959e4909..9345970cfd4 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -459,6 +459,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): "user", "model", "pipeline_run", + "run_metadata", ] artifact_id: Optional[Union[UUID, str]] = Field( default=None, @@ -530,6 +531,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): description="Name/ID of a pipeline run that is associated with this " "artifact version.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the artifact versions by." + ) model_config = ConfigDict(protected_namespaces=()) @@ -549,6 +554,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ModelSchema, ModelVersionArtifactSchema, PipelineRunSchema, + RunMetadataSchema, StepRunInputArtifactSchema, StepRunOutputArtifactSchema, StepRunSchema, @@ -630,6 +636,22 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ) custom_filters.append(pipeline_run_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == ArtifactVersionSchema.id, + RunMetadataSchema.resource_type == MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) + return custom_filters diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index dbc0a0f214f..72af7516544 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -590,6 +590,7 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "user", + "run_metadata", ] name: Optional[str] = Field( @@ -619,6 +620,10 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): default=None, description="Name/ID of the user that created the model version.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the model versions by." + ) _model_id: UUID = PrivateAttr(None) @@ -652,6 +657,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, UserSchema, + RunMetadataSchema, ) if self.user: @@ -665,6 +671,22 @@ def get_custom_filters( ) custom_filters.append(user_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == ModelVersionSchema.id, + RunMetadataSchema.resource_type == MetadataResourceTypes.MODEL_VERSION, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) + return custom_filters def apply_filter( diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 26f517acdd3..6b1764a8e60 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -587,6 +587,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): "stack_component", "pipeline_name", "templatable", + "run_metadata", ] name: Optional[str] = Field( default=None, @@ -665,6 +666,10 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): default=None, description="Name/ID of the user that created the run.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the pipeline runs by." + ) # TODO: Remove once frontend is ready for it. This is replaced by the more # generic `pipeline` filter below. pipeline_name: Optional[str] = Field( @@ -694,7 +699,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): templatable: Optional[bool] = Field( default=None, description="Whether the run is templatable." ) - model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( @@ -718,6 +722,7 @@ def get_custom_filters( PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataSchema, ScheduleSchema, StackComponentSchema, StackCompositionSchema, @@ -887,5 +892,20 @@ def get_custom_filters( ) custom_filters.append(templatable_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == PipelineRunSchema.id, + RunMetadataSchema.resource_type == MetadataResourceTypes.PIPELINE_RUN, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) return custom_filters diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 1e1d2cac98a..cd037a01638 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -491,6 +491,7 @@ class StepRunFilter(WorkspaceScopedFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "model", + "run_metadata", ] name: Optional[str] = Field( @@ -553,6 +554,10 @@ class StepRunFilter(WorkspaceScopedFilter): default=None, description="Name/ID of the model associated with the step run.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the step runs by." + ) model_config = ConfigDict(protected_namespaces=()) @@ -572,6 +577,7 @@ def get_custom_filters( ModelSchema, ModelVersionSchema, StepRunSchema, + RunMetadataSchema, ) if self.model: @@ -583,5 +589,20 @@ def get_custom_filters( ), ) custom_filters.append(model_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == StepRunSchema.id, + RunMetadataSchema.resource_type == MetadataResourceTypes.STEP_RUN, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) return custom_filters From fdf89457578f7e5fa1571369f276b670bb8fa1ce Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 11:10:40 +0100 Subject: [PATCH 026/124] adding log_metadata to __all__ --- src/zenml/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 9821f414104..2fbd2231140 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -60,6 +60,7 @@ "log_artifact_metadata", "log_model_metadata", "log_step_metadata", + "log_metadata", "Model", "link_artifact_to_model", "pipeline", From 39f5bf8c7ebef8a7003383e89e3bfdd41b45935b Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 11:16:49 +0100 Subject: [PATCH 027/124] checkpoint with float casting --- src/zenml/models/v2/base/filter.py | 36 +++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 25e010b3db4..af1128e8047 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -207,21 +207,21 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: # Try to cast the column to a numeric type for numeric operations try: numeric_column = cast(column, Float) + if not isinstance(numeric_column, (int, float)): + raise ValueError("something went wrong!", numeric_column) if self.operation == GenericFilterOps.GT: - return numeric_column > self.value + return numeric_column > float(self.value) if self.operation == GenericFilterOps.LT: - return numeric_column < self.value + return numeric_column < float(self.value) if self.operation == GenericFilterOps.GTE: - return numeric_column >= self.value + return numeric_column >= float(self.value) if self.operation == GenericFilterOps.LTE: - return numeric_column <= self.value + return numeric_column <= float(self.value) except Exception: # Handle the exception or fallback as needed raise ValueError( "Failed to cast column to numeric type for comparison" ) - else: - raise ValueError("Invalid operation or incompatible data type") return column == self.value @@ -243,6 +243,9 @@ def _remove_hyphens_from_value(cls, value: Any) -> Any: if isinstance(value, str): return value.replace("-", "") + if isinstance(value, list): + return [str(v).replace("-", "") for v in value] + return value def generate_query_conditions_from_column(self, column: Any) -> Any: @@ -634,8 +637,18 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: if operator == operator.ONEOF: try: value = json.loads(value) - except: - raise ValueError("Add some error message here....") + if not isinstance(value, list): + raise ValueError( + "When you are using the 'oneof:' filtering " + "make sure that the provided value is a json " + "formatted list." + ) + except ValueError: + raise ValueError( + "When you are using the 'oneof:' filtering " + "make sure that the provided value is a json " + "formatted list." + ) return value, operator @@ -687,8 +700,8 @@ def generate_name_or_id_query_conditions( return or_(*conditions) + @staticmethod def generate_custom_query_conditions_for_column( - self, value: Any, table: Type[SQLModel], column: str, @@ -1074,7 +1087,10 @@ def _define_uuid_filter( # For equality checks, ensure that the value is a valid UUID. if operator == GenericFilterOps.ONEOF and not isinstance(value, list): - raise ValueError("") + raise ValueError( + "If you are using `oneof:` as a filtering op, the value needs " + "to be a json formatted list string." + ) # Generate the filter. uuid_filter = UUIDFilter( From 1c051ec6475713912786803af81666e6fbfe1ca1 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 11:58:43 +0100 Subject: [PATCH 028/124] adding tests --- .../functional/zen_stores/test_zen_store.py | 64 ++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index a68c9fe3de7..c5713a592fc 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import json import os import random import time @@ -19,12 +20,12 @@ from datetime import datetime from string import ascii_lowercase from threading import Thread -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import patch from uuid import UUID, uuid4 import pytest -from pydantic import SecretStr +from pydantic import SecretStr, ValidationError from sqlalchemy.exc import IntegrityError from tests.integration.functional.utils import sample_name @@ -46,6 +47,7 @@ from tests.unit.pipelines.test_build_utils import ( StubLocalRepositoryContext, ) +from zenml import Model, log_metadata, pipeline, step from zenml.artifacts.utils import ( _load_artifact_store, ) @@ -2904,6 +2906,64 @@ def test_deleting_run_deletes_steps(): assert store.list_run_steps(filter_model).total == 0 +@step +def step_to_log_metadata(metadata: Union[str, int, bool]) -> int: + log_metadata({"blupus": metadata}) + return 42 + + +@pipeline(name="aria", model=Model(name="axl"), tags=["cats", "squirrels"]) +def pipeline_to_log_metadata(metadata): + step_to_log_metadata(metadata) + + +def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): + store = clean_client.zen_store + + metadata_values = [3, 25, 100, "random_string", True] + + runs = [] + for v in metadata_values: + runs.append(pipeline_to_log_metadata(v)) + + # Test oneof: name filtering + runs_filter = PipelineRunFilter( + name=f"oneof:{json.dumps([r.name for r in runs[:2]])}" + ) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The first two runs + + # Test oneof: UUID filtering + runs_filter = PipelineRunFilter( + id=f"oneof:{json.dumps([str(r.id) for r in runs[:2]])}" + ) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The first two runs + + # Test oneof: tags filtering + runs_filter = PipelineRunFilter(tag=f'oneof:["cats", "dogs"]') + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == len(metadata_values) # All runs + + runs_filter = PipelineRunFilter(tag=f'oneof:["dogs"]') + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 0 # No runs + + # Test oneof: formatting + with pytest.raises(ValidationError): + PipelineRunFilter(name=f"oneof:random_value") + + # Test metadata filtering + runs_filter = PipelineRunFilter(run_metadata={"blupus": "lt:30"}) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The run with 3 and 25 + + for r in runs: + assert "blupus" in r.run_metadata + assert isinstance(r.run_metadata["blupus"], int) + assert r.run_metadata["blupus"] < 30 + + # .--------------------. # | Pipeline run steps | # '--------------------' From e284808fd64dc61e837d05bf77d110ef9c56c6ed Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 11:59:03 +0100 Subject: [PATCH 029/124] final touches and formatting --- src/zenml/models/v2/base/filter.py | 28 +++++++++++--------- src/zenml/models/v2/core/artifact_version.py | 5 ++-- src/zenml/models/v2/core/model_version.py | 7 ++--- src/zenml/models/v2/core/pipeline_run.py | 5 ++-- src/zenml/models/v2/core/step_run.py | 7 ++--- 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index af1128e8047..29b94144310 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -37,7 +37,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, desc, cast, Float +from sqlalchemy import Float, and_, asc, cast, desc from sqlmodel import SQLModel from zenml.constants import ( @@ -202,25 +202,29 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: GenericFilterOps.GT, GenericFilterOps.LT, GenericFilterOps.GTE, - GenericFilterOps.LTE + GenericFilterOps.LTE, }: - # Try to cast the column to a numeric type for numeric operations try: numeric_column = cast(column, Float) - if not isinstance(numeric_column, (int, float)): - raise ValueError("something went wrong!", numeric_column) if self.operation == GenericFilterOps.GT: - return numeric_column > float(self.value) + return and_( + numeric_column, numeric_column > float(self.value) + ) if self.operation == GenericFilterOps.LT: - return numeric_column < float(self.value) + return and_( + numeric_column, numeric_column < float(self.value) + ) if self.operation == GenericFilterOps.GTE: - return numeric_column >= float(self.value) + return and_( + numeric_column, numeric_column >= float(self.value) + ) if self.operation == GenericFilterOps.LTE: - return numeric_column <= float(self.value) - except Exception: - # Handle the exception or fallback as needed + return and_( + numeric_column, numeric_column <= float(self.value) + ) + except Exception as e: raise ValueError( - "Failed to cast column to numeric type for comparison" + f"Failed to cast column to numeric type for comparison: {e}" ) return column == self.value diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 9345970cfd4..312fd8fb40b 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -533,7 +533,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): ) run_metadata: Optional[Dict[str, str]] = Field( default=None, - description="The run_metadata to filter the artifact versions by." + description="The run_metadata to filter the artifact versions by.", ) model_config = ConfigDict(protected_namespaces=()) @@ -642,7 +642,8 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: for key, value in self.run_metadata.items(): additional_filter = and_( RunMetadataSchema.resource_id == ArtifactVersionSchema.id, - RunMetadataSchema.resource_type == MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataSchema.resource_type + == MetadataResourceTypes.ARTIFACT_VERSION, RunMetadataSchema.key == key, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 72af7516544..f2e3a7aa911 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -622,7 +622,7 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): ) run_metadata: Optional[Dict[str, str]] = Field( default=None, - description="The run_metadata to filter the model versions by." + description="The run_metadata to filter the model versions by.", ) _model_id: UUID = PrivateAttr(None) @@ -656,8 +656,8 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, - UserSchema, RunMetadataSchema, + UserSchema, ) if self.user: @@ -677,7 +677,8 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( RunMetadataSchema.resource_id == ModelVersionSchema.id, - RunMetadataSchema.resource_type == MetadataResourceTypes.MODEL_VERSION, + RunMetadataSchema.resource_type + == MetadataResourceTypes.MODEL_VERSION, RunMetadataSchema.key == key, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 6b1764a8e60..8468c105bee 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -668,7 +668,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): ) run_metadata: Optional[Dict[str, str]] = Field( default=None, - description="The run_metadata to filter the pipeline runs by." + description="The run_metadata to filter the pipeline runs by.", ) # TODO: Remove once frontend is ready for it. This is replaced by the more # generic `pipeline` filter below. @@ -898,7 +898,8 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( RunMetadataSchema.resource_id == PipelineRunSchema.id, - RunMetadataSchema.resource_type == MetadataResourceTypes.PIPELINE_RUN, + RunMetadataSchema.resource_type + == MetadataResourceTypes.PIPELINE_RUN, RunMetadataSchema.key == key, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index cd037a01638..80c29e857cf 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -556,7 +556,7 @@ class StepRunFilter(WorkspaceScopedFilter): ) run_metadata: Optional[Dict[str, str]] = Field( default=None, - description="The run_metadata to filter the step runs by." + description="The run_metadata to filter the step runs by.", ) model_config = ConfigDict(protected_namespaces=()) @@ -576,8 +576,8 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelSchema, ModelVersionSchema, - StepRunSchema, RunMetadataSchema, + StepRunSchema, ) if self.model: @@ -595,7 +595,8 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( RunMetadataSchema.resource_id == StepRunSchema.id, - RunMetadataSchema.resource_type == MetadataResourceTypes.STEP_RUN, + RunMetadataSchema.resource_type + == MetadataResourceTypes.STEP_RUN, RunMetadataSchema.key == key, self.generate_custom_query_conditions_for_column( value=value, From d5bbf7211ebc8a076d7f6138a86f24d2a9e22d50 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 13:06:39 +0100 Subject: [PATCH 030/124] formatting --- tests/integration/functional/zen_stores/test_zen_store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index c5713a592fc..a8cdd14e4eb 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2941,17 +2941,17 @@ def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): assert len(runs) == 2 # The first two runs # Test oneof: tags filtering - runs_filter = PipelineRunFilter(tag=f'oneof:["cats", "dogs"]') + runs_filter = PipelineRunFilter(tag='oneof:["cats", "dogs"]') runs = store.list_runs(runs_filter_model=runs_filter) assert len(runs) == len(metadata_values) # All runs - runs_filter = PipelineRunFilter(tag=f'oneof:["dogs"]') + runs_filter = PipelineRunFilter(tag='oneof:["dogs"]') runs = store.list_runs(runs_filter_model=runs_filter) assert len(runs) == 0 # No runs # Test oneof: formatting with pytest.raises(ValidationError): - PipelineRunFilter(name=f"oneof:random_value") + PipelineRunFilter(name="oneof:random_value") # Test metadata filtering runs_filter = PipelineRunFilter(run_metadata={"blupus": "lt:30"}) From 3a0d4c859ed6ad239df055b2120541ac0fb2162d Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 13:07:10 +0100 Subject: [PATCH 031/124] moved the utils --- src/zenml/__init__.py | 3 +- src/zenml/pipelines/utils.py | 66 ----------------- src/zenml/utils/metadata_utils.py | 114 ++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 68 deletions(-) delete mode 100644 src/zenml/pipelines/utils.py create mode 100644 src/zenml/utils/metadata_utils.py diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 2fbd2231140..01f8bff1d92 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -48,7 +48,7 @@ from zenml.pipelines import get_pipeline_context, pipeline from zenml.steps import step, get_step_context from zenml.steps.utils import log_step_metadata -from zenml.pipelines.utils import log_metadata +from zenml.utils.metadata_utils import log_metadata from zenml.entrypoints import entrypoint __all__ = [ @@ -60,7 +60,6 @@ "log_artifact_metadata", "log_model_metadata", "log_step_metadata", - "log_metadata", "Model", "link_artifact_to_model", "pipeline", diff --git a/src/zenml/pipelines/utils.py b/src/zenml/pipelines/utils.py deleted file mode 100644 index b56224e80ac..00000000000 --- a/src/zenml/pipelines/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Utility functions to run ZenML pipelines.""" - -import contextlib -from typing import Dict, Optional, Union -from uuid import UUID - -from zenml.client import Client -from zenml.enums import MetadataResourceTypes -from zenml.metadata.metadata_types import MetadataType -from zenml.steps.step_context import get_step_context - - -def log_metadata( - metadata: Dict[str, MetadataType], - run_name_id_or_prefix: Optional[Union[str, UUID]] = None, -) -> None: - """Logs metadata. - - Args: - metadata: The metadata to log. - run_name_id_or_prefix: The name, ID or prefix of the run to log metadata - for. Can be omitted when being called inside a step. - - Raises: - ValueError: If no run identifier is provided and the function is not - called from within a step. - """ - step_context = None - if not run_name_id_or_prefix: - with contextlib.suppress(RuntimeError): - step_context = get_step_context() - run_name_id_or_prefix = step_context.pipeline_run.id - - if not run_name_id_or_prefix: - raise ValueError( - "No pipeline name or ID provided and you are not running " - "within a step. Please provide a pipeline name or ID, or " - "provide a run ID." - ) - - client = Client() - if step_context is None and not isinstance(run_name_id_or_prefix, UUID): - run_name_id_or_prefix = client.get_pipeline_run( - name_id_or_prefix=run_name_id_or_prefix, - ).id - - # TODO: Should we also create the corresponding step and model metadata - # as well? - client.create_run_metadata( - metadata=metadata, - resource_id=run_name_id_or_prefix, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py new file mode 100644 index 00000000000..587b7bf9c40 --- /dev/null +++ b/src/zenml/utils/metadata_utils.py @@ -0,0 +1,114 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility functions to handle metadata for ZenML entities.""" + +import contextlib +from typing import Dict, Optional, Union +from uuid import UUID + +from zenml.client import Client +from zenml.enums import MetadataResourceTypes +from zenml.metadata.metadata_types import MetadataType +from zenml.steps.step_context import get_step_context + + +def log_metadata( + metadata: Dict[str, MetadataType], + run: Optional[Union[str, UUID]] = None, + step: Optional[Union[str, UUID]] = None, + model: Optional[Union[str, UUID]] = None, + artifact: Optional[Union[str, UUID]] = None, +) -> None: + """Logs metadata for various resource types in a generalized way. + + Args: + metadata: The metadata to log. + run: The name, ID, or prefix of the run. + step: The name, ID, or prefix of the step. + model: The name, ID, or prefix of the model. + artifact: The name, ID, or prefix of the artifact. + + Raises: + ValueError: If no identifiers are provided and the function is not + called from within a step. + """ + client = Client() + + # Attempt to get the step context if no identifiers are provided + if not any([run, step, model, artifact]): + with contextlib.suppress(RuntimeError): + step_context = get_step_context() + if step_context: + run = step_context.pipeline_run.id + step = step_context.step_run.id + model = step_context.model_version.id + + # Raise an error if still no identifiers are available + if not any([run, step, model, artifact]): + raise ValueError( + "No valid identifiers (run, step, model, or artifact) provided " + "and not running within a step context. Please provide at least " + "one." + ) + + # Create metadata for the run, if available + if run: + if not isinstance(run, UUID): + run = client.get_pipeline_run(name_id_or_prefix=run).id + client.create_run_metadata( + metadata=metadata, + resource_id=run, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + + # Create metadata for the step, if available + if step: + if not isinstance(step, UUID): + assert run is not None, ( + "If you are using `log_metadata` function to log metadata " + "for a step manually, you have to provide a run name id or " + "prefix as well." + ) + step = ( + client.get_pipeline_run(name_id_or_prefix=run).steps[step].id + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + + # Create metadata for the model, if available + if model: + if not isinstance(model, UUID): + model = client.get_model_version( + model_version_name_or_number_or_id=model + ).id + client.create_run_metadata( + metadata=metadata, + resource_id=model, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # Create metadata for the artifact, if available + if artifact: + if not isinstance(artifact, UUID): + artifact = client.get_artifact_version( + name_id_or_prefix=artifact + ).id + client.create_run_metadata( + metadata=metadata, + resource_id=artifact, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) From 5b3b2171687e9820fe6c7fb6ec2a96df10b8992b Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 14:30:55 +0100 Subject: [PATCH 032/124] modified log metadata function --- src/zenml/utils/metadata_utils.py | 167 +++++++++++++++++++----------- 1 file changed, 105 insertions(+), 62 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 587b7bf9c40..79c7d1c14d1 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -14,7 +14,7 @@ """Utility functions to handle metadata for ZenML entities.""" import contextlib -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Union from uuid import UUID from zenml.client import Client @@ -27,8 +27,8 @@ def log_metadata( metadata: Dict[str, MetadataType], run: Optional[Union[str, UUID]] = None, step: Optional[Union[str, UUID]] = None, - model: Optional[Union[str, UUID]] = None, - artifact: Optional[Union[str, UUID]] = None, + model_version: Optional[Union[str, UUID]] = None, + artifact_version: Optional[Union[str, UUID]] = None, ) -> None: """Logs metadata for various resource types in a generalized way. @@ -36,8 +36,8 @@ def log_metadata( metadata: The metadata to log. run: The name, ID, or prefix of the run. step: The name, ID, or prefix of the step. - model: The name, ID, or prefix of the model. - artifact: The name, ID, or prefix of the artifact. + model_version: The name, ID, or prefix of the model version. + artifact_version: The name, ID, or prefix of the artifact version. Raises: ValueError: If no identifiers are provided and the function is not @@ -45,70 +45,113 @@ def log_metadata( """ client = Client() - # Attempt to get the step context if no identifiers are provided - if not any([run, step, model, artifact]): + if not any([run, step, model_version, artifact_version]): + # Executing without any identifiers -> Fetch the step context with contextlib.suppress(RuntimeError): step_context = get_step_context() if step_context: - run = step_context.pipeline_run.id - step = step_context.step_run.id - model = step_context.model_version.id + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.pipeline_run.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.step_run.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + if step_context.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + else: + raise ValueError( + "No valid identifiers (run, step, model, or artifact) " + "provided and not running within a step context. Please " + "provide at least one." + ) + else: + # Executing outside a step execution + metadata_batch: Dict[MetadataResourceTypes, Set[UUID]] = { + MetadataResourceTypes.PIPELINE_RUN: set(), + MetadataResourceTypes.STEP_RUN: set(), + MetadataResourceTypes.ARTIFACT_VERSION: set(), + MetadataResourceTypes.MODEL_VERSION: set(), + } - # Raise an error if still no identifiers are available - if not any([run, step, model, artifact]): - raise ValueError( - "No valid identifiers (run, step, model, or artifact) provided " - "and not running within a step context. Please provide at least " - "one." - ) + if step: + # If a step identifier is provided, try to fetch it. If the + # identifier is a UUID, we can directly fetch it. If the identifier + # is a step name, we also need a run identifier. + if not isinstance(step, UUID): + assert run is not None, ( + "If you are using `log_metadata` function to log metadata " + "for a specific step by name, you need to provide an " + "identifier for the pipeline run it belongs to." + ) + run_model = client.get_pipeline_run(name_id_or_prefix=run) + step_model = run_model.steps[step] + else: + step_model = client.get_run_step(step_run_id=step) + run_model = client.get_pipeline_run( + name_id_or_prefix=step_model.pipeline_run_id + ) - # Create metadata for the run, if available - if run: - if not isinstance(run, UUID): - run = client.get_pipeline_run(name_id_or_prefix=run).id - client.create_run_metadata( - metadata=metadata, - resource_id=run, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) + metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add( + run_model.id + ) + metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) + if step_model.model_version: + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + step_model.model_version.id + ) + + if run: + # If a run identifier is provided, try to fetch it. We may have + # already fetched it and added it to the batch, when we were + # handling the step. In order to avoid duplicate calls, the + # metadata_batch is being used. + run_model = client.get_pipeline_run(name_id_or_prefix=run) - # Create metadata for the step, if available - if step: - if not isinstance(step, UUID): - assert run is not None, ( - "If you are using `log_metadata` function to log metadata " - "for a step manually, you have to provide a run name id or " - "prefix as well." + if run_model.model_version: + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + run_model.model_version.id + ) + metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add( + run_model.id ) - step = ( - client.get_pipeline_run(name_id_or_prefix=run).steps[step].id + + if model_version_id := model_version: + # If a model version identifier is provided, try to fetch it. It is + # possible that we have already fetched this model version when + # we are dealing with the step and run. In order to duplications, + # the metadata_batch is being used. + if not isinstance(model_version_id, UUID): + model_version_id = client.get_model_version( + model_version_name_or_number_or_id=model_version + ).id + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + model_version_id ) - client.create_run_metadata( - metadata=metadata, - resource_id=step, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - # Create metadata for the model, if available - if model: - if not isinstance(model, UUID): - model = client.get_model_version( - model_version_name_or_number_or_id=model - ).id - client.create_run_metadata( - metadata=metadata, - resource_id=model, - resource_type=MetadataResourceTypes.MODEL_VERSION, - ) + if artifact_version_id := artifact_version: + # If an artifact version identifier is provided, try to fetch it and + # add it to the batch. + if not isinstance(artifact_version_id, UUID): + artifact_version_id = client.get_artifact_version( + name_id_or_prefix=artifact_version + ).id + metadata_batch[MetadataResourceTypes.ARTIFACT_VERSION].add( + artifact_version_id + ) - # Create metadata for the artifact, if available - if artifact: - if not isinstance(artifact, UUID): - artifact = client.get_artifact_version( - name_id_or_prefix=artifact - ).id - client.create_run_metadata( - metadata=metadata, - resource_id=artifact, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, - ) + # Create the run metadata + for resource_type, resource_ids in metadata_batch.items(): + for resource_id in resource_ids: + client.create_run_metadata( + metadata=metadata, + resource_id=resource_id, + resource_type=resource_type, + ) From 3d5a9f0cc6bb8f74fa34fae719af77ec7ce39d69 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 16:19:50 +0100 Subject: [PATCH 033/124] checkpoint --- src/zenml/utils/metadata_utils.py | 334 ++++++++++++++++++++---------- 1 file changed, 224 insertions(+), 110 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 79c7d1c14d1..f3d17768879 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,31 +13,102 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" -import contextlib -from typing import Dict, Optional, Set, Union +from typing import Dict, Optional, Set, Union, overload from uuid import UUID from zenml.client import Client from zenml.enums import MetadataResourceTypes +from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.steps.step_context import get_step_context +logger = get_logger(__name__) + +@overload +def log_metadata(metadata: Dict[str, MetadataType]) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + artifact_version_id: UUID, +) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + artifact_name: str, + artifact_version: Optional[str] = None, +) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + model_version_id: UUID, +) -> None: ... + + +@overload def log_metadata( metadata: Dict[str, MetadataType], - run: Optional[Union[str, UUID]] = None, - step: Optional[Union[str, UUID]] = None, - model_version: Optional[Union[str, UUID]] = None, - artifact_version: Optional[Union[str, UUID]] = None, + model_name: str, + model_version: str, +) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + step_id: UUID, +) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + run_id_name_or_prefix: Union[UUID, str], +) -> None: ... + + +@overload +def log_metadata( + metadata: Dict[str, MetadataType], + step_name: str, + run_id_name_or_prefix: Union[UUID, str], +) -> None: ... + + +def log_metadata( + metadata: Dict[str, MetadataType], + # Parameters to manually log metadata for steps and runs + step_id: Optional[UUID] = None, + step_name: Optional[str] = None, + run_id_name_or_prefix: Optional[Union[UUID, str]] = None, + # Parameters to manually log metadata for artifacts + artifact_version_id: Optional[UUID] = None, + artifact_name: Optional[str] = None, + artifact_version: Optional[str] = None, + # Parameters to manually log metadata for models + model_version_id: Optional[UUID] = None, + model_name: Optional[str] = None, + model_version: Optional[str] = None, ) -> None: """Logs metadata for various resource types in a generalized way. Args: metadata: The metadata to log. - run: The name, ID, or prefix of the run. - step: The name, ID, or prefix of the step. - model_version: The name, ID, or prefix of the model version. - artifact_version: The name, ID, or prefix of the artifact version. + step_id: The ID of the step. + step_name: The name of the step. + run_id_name_or_prefix: The id, name or prefix of the run + artifact_version_id: The ID of the artifact version + artifact_name: The name of the artifact. + artifact_version: The version of the artifact. + model_version_id: The ID of the model version. + model_name: The name of the model. + model_version: The version of the model Raises: ValueError: If no identifiers are provided and the function is not @@ -45,113 +116,156 @@ def log_metadata( """ client = Client() - if not any([run, step, model_version, artifact_version]): - # Executing without any identifiers -> Fetch the step context - with contextlib.suppress(RuntimeError): - step_context = get_step_context() - if step_context: - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - if step_context.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, - ) - else: - raise ValueError( - "No valid identifiers (run, step, model, or artifact) " - "provided and not running within a step context. Please " - "provide at least one." - ) - else: - # Executing outside a step execution - metadata_batch: Dict[MetadataResourceTypes, Set[UUID]] = { - MetadataResourceTypes.PIPELINE_RUN: set(), - MetadataResourceTypes.STEP_RUN: set(), - MetadataResourceTypes.ARTIFACT_VERSION: set(), - MetadataResourceTypes.MODEL_VERSION: set(), - } - - if step: - # If a step identifier is provided, try to fetch it. If the - # identifier is a UUID, we can directly fetch it. If the identifier - # is a step name, we also need a run identifier. - if not isinstance(step, UUID): - assert run is not None, ( - "If you are using `log_metadata` function to log metadata " - "for a specific step by name, you need to provide an " - "identifier for the pipeline run it belongs to." - ) - run_model = client.get_pipeline_run(name_id_or_prefix=run) - step_model = run_model.steps[step] - else: - step_model = client.get_run_step(step_run_id=step) - run_model = client.get_pipeline_run( - name_id_or_prefix=step_model.pipeline_run_id - ) - - metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add( - run_model.id - ) - metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) - if step_model.model_version: - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - step_model.model_version.id - ) + # Initialize a batch of request to avoid duplications + metadata_batch: Dict[MetadataResourceTypes, Set[UUID]] = { + MetadataResourceTypes.PIPELINE_RUN: set(), + MetadataResourceTypes.STEP_RUN: set(), + MetadataResourceTypes.ARTIFACT_VERSION: set(), + MetadataResourceTypes.MODEL_VERSION: set(), + } - if run: - # If a run identifier is provided, try to fetch it. We may have - # already fetched it and added it to the batch, when we were - # handling the step. In order to avoid duplicate calls, the - # metadata_batch is being used. - run_model = client.get_pipeline_run(name_id_or_prefix=run) + # If a step name is provided, we need a run_id_name_or_prefix and will log + # metadata for the steps pipeline and model accordingly. + if step_name is not None and run_id_name_or_prefix is not None: + run_model = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + ) + step_model = run_model.steps[step_name] - if run_model.model_version: - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - run_model.model_version.id - ) - metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add( - run_model.id + metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) + metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) + if step_model.model_version: + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + step_model.model_version.id ) - if model_version_id := model_version: - # If a model version identifier is provided, try to fetch it. It is - # possible that we have already fetched this model version when - # we are dealing with the step and run. In order to duplications, - # the metadata_batch is being used. - if not isinstance(model_version_id, UUID): - model_version_id = client.get_model_version( - model_version_name_or_number_or_id=model_version - ).id + # If a step is identified by id, fetch it directly through the client, + # follow a similar procedure and log metadata for its pipeline and model + # as well. + elif step_id is not None: + step_model = client.get_run_step(step_run_id=step_id) + run_model = client.get_pipeline_run( + name_id_or_prefix=step_model.pipeline_run_id + ) + metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) + metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) + if step_model.model_version: metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - model_version_id + step_model.model_version.id ) - if artifact_version_id := artifact_version: - # If an artifact version identifier is provided, try to fetch it and - # add it to the batch. - if not isinstance(artifact_version_id, UUID): - artifact_version_id = client.get_artifact_version( - name_id_or_prefix=artifact_version - ).id - metadata_batch[MetadataResourceTypes.ARTIFACT_VERSION].add( - artifact_version_id + # If a pipeline run id is identified, we need to log metadata to it and its + # model as well. + elif run_id_name_or_prefix is not None: + run_model = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + ) + if run_model.model_version: + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + run_model.model_version.id ) + metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) - # Create the run metadata - for resource_type, resource_ids in metadata_batch.items(): - for resource_id in resource_ids: - client.create_run_metadata( - metadata=metadata, - resource_id=resource_id, - resource_type=resource_type, + # If the user provides a model name and version, we use to model abstraction + # to fetch the model version and attach the corresponding metadata to it. + elif model_name is not None and model_version is not None: + from zenml import Model + + mv = Model(name=model_name, version=model_version) + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(mv.id) + + # If the user provides a model version id, we use the client to fetch it and + # attach the metadata to it. + elif model_version_id is not None: + model_version_id = client.get_model_version( + model_version_name_or_number_or_id=model_version_id + ).id + metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( + model_version_id + ) + + # If the user provides an artifact name, there are two possibilities. If + # an artifact version is also provided with the name, we use both to fetch + # the artifact version and use it to log the metadata. If no version is + # provided, we make sure that the call is happening within a step, otherwise + # we fail. + elif artifact_name is not None: + if artifact_version: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=artifact_version + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + else: + try: + step_context = get_step_context() + step_context.add_output_metadata( + metadata=metadata, output_name=artifact_name + ) + except RuntimeError: + raise ValueError( + "You are calling 'log_metadata(artifact_name='...') " + "without specifying a version outside of a step execution." ) + + # If the user directly provides an artifact_version_id, we use the client to + # fetch is and attach the metadata accordingly. + elif artifact_version_id is not None: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_version_id, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + + # If every additional value is None, that means we are calling it bare bones + # and this call needs to happen during a step execution. We will use the + # step context to fetch the step, run and possibly the model version and + # attach the metadata accordingly. + elif all( + v is None + for v in [ + step_id, + step_name, + run_id_name_or_prefix, + artifact_version_id, + artifact_name, + artifact_version, + model_version_id, + model_name, + model_version, + ] + ): + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "You are calling 'log_metadata()' outside of a step execution. " + "If you would like to add metadata to a ZenML entity outside " + "of the step execution, please provide the required " + "identifiers." + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.pipeline_run.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.step_run.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + if step_context.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + else: + raise ValueError() \ No newline at end of file From e3079a3391b6ccdc0a58decd2ed9abf07ff62741 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 16:49:29 +0100 Subject: [PATCH 034/124] deprecating the old functions --- src/zenml/artifacts/utils.py | 6 +++++- src/zenml/model/utils.py | 5 +++++ src/zenml/steps/utils.py | 5 +++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index fde0b1d4c69..200e49896a1 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -352,7 +352,7 @@ def log_artifact_metadata( not provided, when being called inside a step that produces an artifact named `artifact_name`, the metadata will be associated to the corresponding newly created artifact. Or, if not provided when - being called outside of a step, or in a step that does not produce + being called outside a step, or in a step that does not produce any artifact named `artifact_name`, the metadata will be associated to the latest version of that artifact. @@ -361,6 +361,10 @@ def log_artifact_metadata( called inside a step with a single output, or, if neither an artifact nor an output with the given name exists. """ + logger.warning( + "The `log_artifact_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) try: step_context = get_step_context() in_step_outputs = (artifact_name in step_context._outputs) or ( diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 2593b606d17..5ec2123098f 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -56,6 +56,11 @@ def log_model_metadata( ValueError: If no model name/version is provided and the function is not called inside a step with configured `model` in decorator. """ + logger.warning( + "The `log_model_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) + if model_name and model_version: from zenml import Model diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index 40324c59200..20d838b9a35 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -438,6 +438,11 @@ def log_step_metadata( from within a step or if no pipeline name or ID is provided and the function is not called from within a step. """ + logger.warning( + "The `log_step_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) + step_context = None if not step_name: with contextlib.suppress(RuntimeError): From c3e69c28172773cdbb5e2a1ceea0c7a96183c35f Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 16:57:12 +0100 Subject: [PATCH 035/124] linting and final fixes --- src/zenml/models/v2/base/filter.py | 5 ++++- src/zenml/utils/metadata_utils.py | 10 +++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 29b94144310..b0e414fec92 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -206,6 +206,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: }: try: numeric_column = cast(column, Float) + + assert self.value is not None + if self.operation == GenericFilterOps.GT: return and_( numeric_column, numeric_column > float(self.value) @@ -224,7 +227,7 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: ) except Exception as e: raise ValueError( - f"Failed to cast column to numeric type for comparison: {e}" + f"Failed to compare the column to the numeric value: {e}" ) return column == self.value diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index f3d17768879..f5108047766 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -31,6 +31,7 @@ def log_metadata(metadata: Dict[str, MetadataType]) -> None: ... @overload def log_metadata( + *, metadata: Dict[str, MetadataType], artifact_version_id: UUID, ) -> None: ... @@ -38,6 +39,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], artifact_name: str, artifact_version: Optional[str] = None, @@ -46,6 +48,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], model_version_id: UUID, ) -> None: ... @@ -53,6 +56,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], model_name: str, model_version: str, @@ -61,6 +65,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], step_id: UUID, ) -> None: ... @@ -68,6 +73,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], run_id_name_or_prefix: Union[UUID, str], ) -> None: ... @@ -75,6 +81,7 @@ def log_metadata( @overload def log_metadata( + *, metadata: Dict[str, MetadataType], step_name: str, run_id_name_or_prefix: Union[UUID, str], @@ -114,6 +121,7 @@ def log_metadata( ValueError: If no identifiers are provided and the function is not called from within a step. """ + # Initialize the client client = Client() # Initialize a batch of request to avoid duplications @@ -268,4 +276,4 @@ def log_metadata( ) else: - raise ValueError() \ No newline at end of file + raise ValueError() From 2d4c723bd81b206add97f495e9d4e13819d6d5bb Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 6 Nov 2024 16:59:48 +0100 Subject: [PATCH 036/124] better error message --- src/zenml/models/v2/base/filter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index b0e414fec92..43792314c5c 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -1126,7 +1126,10 @@ def _define_str_filter( """ # For equality checks, ensure that the value is a valid UUID. if operator == GenericFilterOps.ONEOF and not isinstance(value, list): - raise ValueError("") + raise ValueError( + "If you are using `oneof:` as a filtering op, the value needs " + "to be a json formatted list string." + ) # Generate the filter. str_filter = StrFilter( From fbd02008ccbc7a0a9f45657beba27e71002bf08c Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:24:45 +0100 Subject: [PATCH 037/124] fixing the client method --- src/zenml/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/client.py b/src/zenml/client.py index b32b8cce04c..cd9fa28446e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -3794,6 +3794,7 @@ def list_pipeline_runs( templatable: Optional[bool] = None, tag: Optional[str] = None, user: Optional[Union[UUID, str]] = None, + run_metadata: Optional[Dict[str, str]] = None, pipeline: Optional[Union[UUID, str]] = None, code_repository: Optional[Union[UUID, str]] = None, model: Optional[Union[UUID, str]] = None, @@ -3872,6 +3873,7 @@ def list_pipeline_runs( tag=tag, unlisted=unlisted, user=user, + run_metadata=run_metadata, pipeline=pipeline, code_repository=code_repository, stack=stack, From ec7dc02021b5fe00a7f1e2a6aeaa442701d22ac0 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:33:57 +0100 Subject: [PATCH 038/124] better error message --- src/zenml/utils/metadata_utils.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index f5108047766..61ff8ad2c7c 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -276,4 +276,31 @@ def log_metadata( ) else: - raise ValueError() + raise ValueError( + """ + Unsupported way to call the `log_metadata`. Possible combinations " + include: + + # Inside a step + # Logs the metadata to the step, its run and possibly its model + log_metadata(metadata={}) + + # Manually logging for a step + # Logs the metadata to the step, its run and possibly its model + log_metadata(metadata={}, step_name=..., run_id_name_or_prefix=...) + log_metadata(metadata={}, step_id=...) + + # Manually logging for a run + # Logs the metadata to the run, possibly its model + log_metadata(metadata={}, run_id_name_or_prefix=...) + + # Manually logging for a model + log_metadata(metadata={}, model_name=..., model_version=...) + log_metadata(metadata={}, model_version_id=...) + + # Manually logging for an artifact + log_metadata(metadata={}, artifact_name=...) # inside a step + log_metadata(metadata={}, artifact_name=..., artifact_version=...) + log_metadata(metadata={}, artifact_version_id=...) + """ + ) From 1fafb7e7218ef4664c2503168ae785ca4743dcec Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:37:44 +0100 Subject: [PATCH 039/124] consistent creation\ --- src/zenml/utils/metadata_utils.py | 68 +++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 61ff8ad2c7c..76af71ac1a5 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -124,14 +124,6 @@ def log_metadata( # Initialize the client client = Client() - # Initialize a batch of request to avoid duplications - metadata_batch: Dict[MetadataResourceTypes, Set[UUID]] = { - MetadataResourceTypes.PIPELINE_RUN: set(), - MetadataResourceTypes.STEP_RUN: set(), - MetadataResourceTypes.ARTIFACT_VERSION: set(), - MetadataResourceTypes.MODEL_VERSION: set(), - } - # If a step name is provided, we need a run_id_name_or_prefix and will log # metadata for the steps pipeline and model accordingly. if step_name is not None and run_id_name_or_prefix is not None: @@ -140,11 +132,21 @@ def log_metadata( ) step_model = run_model.steps[step_name] - metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) - metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) if step_model.model_version: - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - step_model.model_version.id + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, ) # If a step is identified by id, fetch it directly through the client, @@ -155,11 +157,21 @@ def log_metadata( run_model = client.get_pipeline_run( name_id_or_prefix=step_model.pipeline_run_id ) - metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) - metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id) + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) if step_model.model_version: - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - step_model.model_version.id + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, ) # If a pipeline run id is identified, we need to log metadata to it and its @@ -168,11 +180,17 @@ def log_metadata( run_model = client.get_pipeline_run( name_id_or_prefix=run_id_name_or_prefix ) + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) if run_model.model_version: - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - run_model.model_version.id + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, ) - metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id) # If the user provides a model name and version, we use to model abstraction # to fetch the model version and attach the corresponding metadata to it. @@ -180,7 +198,11 @@ def log_metadata( from zenml import Model mv = Model(name=model_name, version=model_version) - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(mv.id) + client.create_run_metadata( + metadata=metadata, + resource_id=mv.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) # If the user provides a model version id, we use the client to fetch it and # attach the metadata to it. @@ -188,8 +210,10 @@ def log_metadata( model_version_id = client.get_model_version( model_version_name_or_number_or_id=model_version_id ).id - metadata_batch[MetadataResourceTypes.MODEL_VERSION].add( - model_version_id + client.create_run_metadata( + metadata=metadata, + resource_id=model_version_id, + resource_type=MetadataResourceTypes.MODEL_VERSION, ) # If the user provides an artifact name, there are two possibilities. If From d90f55d7d97851352317c163db53bb91cfd973b8 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:52:05 +0100 Subject: [PATCH 040/124] adjusting tests --- .../functional/artifacts/test_utils.py | 36 ++++++++++--------- .../pipelines/test_pipeline_context.py | 4 +-- .../functional/steps/test_step_context.py | 4 +-- tests/integration/functional/test_client.py | 6 ++-- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 2421f55fe38..938cd7abba1 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -14,7 +14,7 @@ from zenml import ( load_artifact, - log_artifact_metadata, + log_metadata, pipeline, save_artifact, step, @@ -120,23 +120,25 @@ def _load_pipeline(expected_value, name, version): ) -def test_log_artifact_metadata_existing(clean_client): +def test_log_metadata_existing(clean_client): """Test logging artifact metadata for existing artifacts.""" save_artifact(42, "meaning_of_life") - log_artifact_metadata( - {"description": "Aria is great!"}, artifact_name="meaning_of_life" + log_metadata( + metadata={"description": "Aria is great!"}, + artifact_name="meaning_of_life", ) save_artifact(43, "meaning_of_life", version="43") - log_artifact_metadata( - {"description_2": "Blupus is great!"}, artifact_name="meaning_of_life" + log_metadata( + metadata={"description_2": "Blupus is great!"}, + artifact_name="meaning_of_life", ) - log_artifact_metadata( - {"description_3": "Axl is great!"}, + log_metadata( + metadata={"description_3": "Axl is great!"}, artifact_name="meaning_of_life", artifact_version="1", ) - log_artifact_metadata( - { + log_metadata( + metadata={ "float": 1.0, "int": 1, "str": "1.0", @@ -183,11 +185,11 @@ def artifact_metadata_logging_step() -> str: "description": "Aria is great!", "metrics": {"accuracy": 0.9}, } - log_artifact_metadata(output_metadata) + log_metadata(metadata=output_metadata) return "42" -def test_log_artifact_metadata_single_output(clean_client): +def test_log_metadata_single_output(clean_client): """Test logging artifact metadata for a single output.""" @pipeline @@ -212,11 +214,11 @@ def artifact_multi_output_metadata_logging_step() -> ( "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_artifact_metadata(metadata=output_metadata, artifact_name="int_output") + log_metadata(metadata=output_metadata, artifact_name="int_output") return "42", 42 -def test_log_artifact_metadata_multi_output(clean_client): +def test_log_metadata_multi_output(clean_client): """Test logging artifact metadata for multiple outputs.""" @pipeline @@ -245,14 +247,14 @@ def wrong_artifact_multi_output_metadata_logging_step() -> ( "description": "Axl is great!", "metrics": {"accuracy": 0.9}, } - log_artifact_metadata(output_metadata) + log_metadata(output_metadata) return "42", 42 -def test_log_artifact_metadata_raises_error_if_output_name_unclear( +def test_log_metadata_raises_error_if_output_name_unclear( clean_client, ): - """Test that `log_artifact_metadata` raises an error if the output name is unclear.""" + """Test that `log_metadata` raises an error if the output name is unclear.""" @pipeline def artifact_metadata_logging_pipeline(): diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 978b0c81844..e6d3d2d634f 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -10,7 +10,7 @@ pipeline, step, ) -from zenml.artifacts.utils import log_artifact_metadata +from zenml.artifacts.utils import log_metadata from zenml.client import Client @@ -103,7 +103,7 @@ def producer() -> Annotated[str, "bar"]: """Produce artifact with metadata and attach metadata to model version.""" ver = get_step_context().model.version log_model_metadata(metadata={"foobar": "model_meta_" + ver}) - log_artifact_metadata(metadata={"foobar": "artifact_meta_" + ver}) + log_metadata(metadata={"foobar": "artifact_meta_" + ver}) return "artifact_data_" + ver diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index 4442f84b08e..0291aa68146 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from typing_extensions import Annotated -from zenml import get_step_context, log_artifact_metadata, pipeline, step +from zenml import get_step_context, log_metadata, pipeline, step from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client from zenml.enums import ArtifactType @@ -92,7 +92,7 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: - log_artifact_metadata(metadata={"some_key": "some_value"}) + log_metadata(metadata={"some_key": "some_value"}) return 42 diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 9daa823a1a7..e1d2f530075 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -32,7 +32,7 @@ ExternalArtifact, get_pipeline_context, get_step_context, - log_artifact_metadata, + log_metadata, log_model_metadata, pipeline, save_artifact, @@ -968,7 +968,7 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: """Produce artifact with metadata.""" from zenml.client import Client - log_artifact_metadata(metadata={"some_meta": "meta_new_one"}) + log_metadata(metadata={"some_meta": "meta_new_one"}) client = Client() @@ -1132,7 +1132,7 @@ def dummy(): save_artifact( data="body_preexisting", name="preexisting", version="1.2.3" ) - log_artifact_metadata( + log_metadata( metadata={"some_meta": "meta_preexisting"}, artifact_name="preexisting", artifact_version="1.2.3", From e0db4183ffe9a259e8bd16e6142a974f8440dd75 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:52:12 +0100 Subject: [PATCH 041/124] linting --- src/zenml/utils/metadata_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 76af71ac1a5..735da38fea8 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" -from typing import Dict, Optional, Set, Union, overload +from typing import Dict, Optional, Union, overload from uuid import UUID from zenml.client import Client From 14dfdea4ca121ff900e00c6671839fce662bdbd8 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 14:55:24 +0100 Subject: [PATCH 042/124] changes for step metadata --- src/zenml/__init__.py | 1 + .../pipelines/test_pipeline_context.py | 2 +- .../integration/functional/steps/test_utils.py | 18 +++++++++--------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 01f8bff1d92..a20957a33a3 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -57,6 +57,7 @@ "get_pipeline_context", "get_step_context", "load_artifact", + "log_metadata", "log_artifact_metadata", "log_model_metadata", "log_step_metadata", diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index e6d3d2d634f..9895e43d142 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -10,7 +10,7 @@ pipeline, step, ) -from zenml.artifacts.utils import log_metadata +from zenml import log_metadata from zenml.client import Client diff --git a/tests/integration/functional/steps/test_utils.py b/tests/integration/functional/steps/test_utils.py index 7bdff4867e9..4b1e690ad18 100644 --- a/tests/integration/functional/steps/test_utils.py +++ b/tests/integration/functional/steps/test_utils.py @@ -15,7 +15,7 @@ """Tests for utility functions and classes to run ZenML steps.""" from zenml import pipeline, step -from zenml.steps.utils import log_step_metadata +from zenml import log_metadata @step @@ -31,11 +31,11 @@ def step_metadata_logging_step_inside_run() -> str: "description": "Aria is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata(metadata=step_metadata) + log_metadata(metadata=step_metadata) return "42" -def test_log_step_metadata_within_step(clean_client): +def test_log_metadata_within_step(clean_client): """Test logging step metadata for the latest run.""" @pipeline @@ -54,7 +54,7 @@ def step_metadata_logging_pipeline(): assert run_metadata["metrics"] == {"accuracy": 0.9} -def test_log_step_metadata_using_latest_run(clean_client): +def test_log_metadata_using_latest_run(clean_client): """Test logging step metadata for the latest run.""" @pipeline @@ -74,10 +74,10 @@ def step_metadata_logging_pipeline(): "description": "Axl is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata( + log_metadata( metadata=step_metadata, step_name="step_metadata_logging_step", - pipeline_name_id_or_prefix="step_metadata_logging_pipeline", + run_id_name_or_prefix="step_metadata_logging_pipeline", ) run_after_log = step_metadata_logging_pipeline.model.last_run run_metadata_after_log = run_after_log.steps[ @@ -89,7 +89,7 @@ def step_metadata_logging_pipeline(): assert run_metadata_after_log["metrics"] == {"accuracy": 0.9} -def test_log_step_metadata_using_specific_params(clean_client): +def test_log_metadata_using_specific_params(clean_client): """Test logging step metadata for a specific step.""" @pipeline @@ -114,10 +114,10 @@ def step_metadata_logging_pipeline(): "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata( + log_metadata( metadata=step_metadata, step_name="step_metadata_logging_step", - run_id=step_run_id, + run_id_name_or_prefix=step_run_id, ) run_after_log = step_metadata_logging_pipeline.model.last_run run_metadata_after_log = run_after_log.steps[ From d89358dcd48778cfc5177e7c6288af25a1547cab Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 15:03:12 +0100 Subject: [PATCH 043/124] more test adjustments --- .../functional/model/test_model_version.py | 24 +++++++++++-------- .../pipelines/test_pipeline_context.py | 18 ++++++++------ .../functional/steps/test_utils.py | 3 +-- tests/integration/functional/test_client.py | 11 +++++---- 4 files changed, 32 insertions(+), 24 deletions(-) diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index b8cbf95e738..cdf98ac9301 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -19,12 +19,12 @@ from typing_extensions import Annotated from tests.integration.functional.utils import random_str -from zenml import get_step_context, pipeline, step +from zenml import get_step_context, log_metadata, pipeline, step from zenml.artifacts.utils import save_artifact from zenml.client import Client from zenml.enums import ModelStages from zenml.model.model import Model -from zenml.model.utils import link_artifact_to_model, log_model_metadata +from zenml.model.utils import link_artifact_to_model from zenml.models import TagRequest @@ -107,10 +107,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): @step def step_metadata_logging_functional(mdl_name: str): """Functional logging using implicit Model from context.""" - log_model_metadata({"foo": "bar"}) + log_metadata({"foo": "bar"}) assert get_step_context().model.run_metadata["foo"] == "bar" - log_model_metadata( - {"foo": "bar"}, model_name=mdl_name, model_version="other" + log_metadata( + metadata={"foo": "bar"}, model_name=mdl_name, model_version="other" ) @@ -409,18 +409,22 @@ def test_metadata_logging_functional(self): ) mv._get_or_create_model_version() - log_model_metadata( - {"foo": "bar"}, model_name=mv.name, model_version=mv.number + log_metadata( + metadata={"foo": "bar"}, + model_name=mv.name, + model_version=str(mv.number), ) assert len(mv.run_metadata) == 1 assert mv.run_metadata["foo"] == "bar" with pytest.raises(ValueError): - log_model_metadata({"foo": "bar"}) + log_metadata({"foo": "bar"}) - log_model_metadata( - {"bar": "foo"}, model_name=mv.name, model_version="latest" + log_metadata( + metadata={"bar": "foo"}, + model_name=mv.name, + model_version="latest", ) assert len(mv.run_metadata) == 2 diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 9895e43d142..7f044c09f12 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -6,11 +6,10 @@ Model, get_pipeline_context, get_step_context, - log_model_metadata, + log_metadata, pipeline, step, ) -from zenml import log_metadata from zenml.client import Client @@ -101,17 +100,22 @@ def test_that_argument_as_get_artifact_of_model_in_pipeline_context_fails_if_not @step def producer() -> Annotated[str, "bar"]: """Produce artifact with metadata and attach metadata to model version.""" - ver = get_step_context().model.version - log_model_metadata(metadata={"foobar": "model_meta_" + ver}) - log_metadata(metadata={"foobar": "artifact_meta_" + ver}) - return "artifact_data_" + ver + model = get_step_context().model + + log_metadata( + metadata={"foobar": "model_meta_" + model.ver}, + model_name=model.name, + model_version=model.version, + ) + log_metadata(metadata={"foobar": "artifact_meta_" + model.ver}) + return "artifact_data_" + model.ver @step def asserter(artifact: str, artifact_metadata: str, model_metadata: str): """Assert that passed in values are loaded in lazy mode. - They do not exists before actual run of the pipeline. + They do not exist before actual run of the pipeline. """ ver = get_step_context().model.version assert artifact == "artifact_data_" + ver diff --git a/tests/integration/functional/steps/test_utils.py b/tests/integration/functional/steps/test_utils.py index 4b1e690ad18..8ba0fdfebc2 100644 --- a/tests/integration/functional/steps/test_utils.py +++ b/tests/integration/functional/steps/test_utils.py @@ -14,8 +14,7 @@ """Tests for utility functions and classes to run ZenML steps.""" -from zenml import pipeline, step -from zenml import log_metadata +from zenml import log_metadata, pipeline, step @step diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index e1d2f530075..9d5ea6b5031 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -33,7 +33,6 @@ get_pipeline_context, get_step_context, log_metadata, - log_model_metadata, pipeline, save_artifact, step, @@ -972,12 +971,14 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: client = Client() - log_model_metadata( + model = get_step_context().model + + log_metadata( metadata={"some_meta": "meta_new_one"}, + model_name=model.name, + model_version=model.model_version, ) - model = get_step_context().model - mv = client.create_model_version( model_name_or_id=model.name, name="model_version2", @@ -1137,7 +1138,7 @@ def dummy(): artifact_name="preexisting", artifact_version="1.2.3", ) - log_model_metadata( + log_metadata( metadata={"some_meta": "meta_preexisting"}, model_name="aria", model_version="model_version", From 7d903052b7484e1f5656aa8e54485eba57f902af Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 15:11:49 +0100 Subject: [PATCH 044/124] testing unit tests --- tests/unit/models/test_filter_models.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/unit/models/test_filter_models.py b/tests/unit/models/test_filter_models.py index 779d6120b01..46b711bb7cc 100644 --- a/tests/unit/models/test_filter_models.py +++ b/tests/unit/models/test_filter_models.py @@ -182,7 +182,7 @@ def test_datetime_filter_model(): filter_class=DatetimeFilter, filter_value=filter_value, expected_value=expected_value, - ignore_operators=[GenericFilterOps.IN], + ignore_operators=[GenericFilterOps.IN, GenericFilterOps.ONEOF], ) @@ -231,6 +231,7 @@ def test_uuid_filter_model(): filter_class=UUIDFilter, filter_value=filter_value, expected_value=str(filter_value).replace("-", ""), + ignore_operators=[GenericFilterOps.ONEOF], ) @@ -245,7 +246,10 @@ def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality(): """Test filtering with other UUID operations is possible with non-UUIDs.""" filter_value = "a92k34" for filter_op in UUIDFilter.ALLOWED_OPS: - if filter_op == GenericFilterOps.EQUALS: + if ( + filter_op == GenericFilterOps.EQUALS + or filter_op == GenericFilterOps.ONEOF + ): continue filter_model = SomeFilterModel( uuid_field=f"{filter_op}:{filter_value}" @@ -264,4 +268,5 @@ def test_string_filter_model(): filter_field="str_field", filter_class=StrFilter, filter_value="a_random_string", + ignore_operators=[GenericFilterOps.ONEOF], ) From b0609871f3b34f9b8e9713e4706183330c6ab95c Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 15:19:39 +0100 Subject: [PATCH 045/124] linting --- src/zenml/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zenml/client.py b/src/zenml/client.py index cd9fa28446e..e1ab579a272 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -3834,6 +3834,7 @@ def list_pipeline_runs( templatable: If the runs should be templatable or not. tag: Tag to filter by. user: The name/ID of the user to filter by. + run_metadata: The run_metadata of the run to filter by. pipeline: The name/ID of the pipeline to filter by. code_repository: Filter by code repository name/ID. model: Filter by model name/ID. From 43a7034460b88086d1e3a7c4fd58144252a4f8f2 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 15:42:53 +0100 Subject: [PATCH 046/124] fixing more tests --- src/zenml/client.py | 2 +- src/zenml/utils/metadata_utils.py | 25 +++++++++++++------ .../functional/artifacts/test_utils.py | 5 ++-- .../functional/steps/test_utils.py | 3 +-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index e1ab579a272..d1acc0e033b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4195,7 +4195,7 @@ def get_artifact_version( ), ) except RuntimeError: - pass # Cannot link to step run if called outside of a step + pass # Cannot link to step run if called outside a step return artifact def list_artifact_versions( diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 735da38fea8..deb7700c384 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" +import contextlib from typing import Dict, Optional, Union, overload from uuid import UUID @@ -216,11 +217,12 @@ def log_metadata( resource_type=MetadataResourceTypes.MODEL_VERSION, ) - # If the user provides an artifact name, there are two possibilities. If + # If the user provides an artifact name, there are three possibilities. If # an artifact version is also provided with the name, we use both to fetch # the artifact version and use it to log the metadata. If no version is - # provided, we make sure that the call is happening within a step, otherwise - # we fail. + # provided, if the function is called within a step we search the artifacts + # of the step if not we fetch the latest version and attach the metadata + # to the latest version. elif artifact_name is not None: if artifact_version: artifact_version_model = client.get_artifact_version( @@ -232,15 +234,22 @@ def log_metadata( resource_type=MetadataResourceTypes.ARTIFACT_VERSION, ) else: - try: + step_context = None + with contextlib.suppress(RuntimeError): step_context = get_step_context() + + if step_context: step_context.add_output_metadata( metadata=metadata, output_name=artifact_name ) - except RuntimeError: - raise ValueError( - "You are calling 'log_metadata(artifact_name='...') " - "without specifying a version outside of a step execution." + else: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, ) # If the user directly provides an artifact_version_id, we use the client to diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 938cd7abba1..b0c2493ecdb 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -17,6 +17,7 @@ log_metadata, pipeline, save_artifact, +log_artifact_metadata, step, ) from zenml.artifacts.utils import register_artifact @@ -185,7 +186,7 @@ def artifact_metadata_logging_step() -> str: "description": "Aria is great!", "metrics": {"accuracy": 0.9}, } - log_metadata(metadata=output_metadata) + log_artifact_metadata(metadata=output_metadata) return "42" @@ -247,7 +248,7 @@ def wrong_artifact_multi_output_metadata_logging_step() -> ( "description": "Axl is great!", "metrics": {"accuracy": 0.9}, } - log_metadata(output_metadata) + log_artifact_metadata(output_metadata) return "42", 42 diff --git a/tests/integration/functional/steps/test_utils.py b/tests/integration/functional/steps/test_utils.py index 8ba0fdfebc2..ed983547a2f 100644 --- a/tests/integration/functional/steps/test_utils.py +++ b/tests/integration/functional/steps/test_utils.py @@ -115,8 +115,7 @@ def step_metadata_logging_pipeline(): } log_metadata( metadata=step_metadata, - step_name="step_metadata_logging_step", - run_id_name_or_prefix=step_run_id, + step_id=step_run_id, ) run_after_log = step_metadata_logging_pipeline.model.last_run run_metadata_after_log = run_after_log.steps[ From 28ecdc1ddc69f7d62633daf49dd111f64c83c611 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 15:50:23 +0100 Subject: [PATCH 047/124] fixing more tests --- src/zenml/client.py | 2 ++ .../functional/pipelines/test_pipeline_context.py | 6 +++--- tests/integration/functional/steps/test_step_context.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index d1acc0e033b..e51ebdf17c7 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4223,6 +4223,7 @@ def list_artifact_versions( user: Optional[Union[UUID, str]] = None, model: Optional[Union[UUID, str]] = None, pipeline_run: Optional[Union[UUID, str]] = None, + run_metadata: Optional[Dict[str,str]] = None, tag: Optional[str] = None, hydrate: bool = False, ) -> Page[ArtifactVersionResponse]: @@ -4254,6 +4255,7 @@ def list_artifact_versions( user: Filter by user name or ID. model: Filter by model name or ID. pipeline_run: Filter by pipeline run name or ID. + run_metadata: Filter by run metadata. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 7f044c09f12..81355b8cab7 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -103,12 +103,12 @@ def producer() -> Annotated[str, "bar"]: model = get_step_context().model log_metadata( - metadata={"foobar": "model_meta_" + model.ver}, + metadata={"foobar": "model_meta_" + model.version}, model_name=model.name, model_version=model.version, ) - log_metadata(metadata={"foobar": "artifact_meta_" + model.ver}) - return "artifact_data_" + model.ver + log_metadata(metadata={"foobar": "artifact_meta_" + model.version}) + return "artifact_data_" + model.version @step diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index 0291aa68146..f2b85868a89 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -92,7 +92,7 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: - log_metadata(metadata={"some_key": "some_value"}) + log_metadata(metadata={"some_key": "some_value"}, artifact_name="my_output") return 42 From e0c5e4f1a3ea2c67ec8493ca6dc2e79469463a1b Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 8 Nov 2024 16:36:24 +0100 Subject: [PATCH 048/124] more test fixes --- src/zenml/client.py | 2 +- src/zenml/models/v2/base/filter.py | 16 ++++++++++------ .../functional/artifacts/test_utils.py | 2 +- .../pipelines/test_pipeline_context.py | 5 ++++- .../functional/steps/test_step_context.py | 4 +++- tests/integration/functional/test_client.py | 2 +- 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index e51ebdf17c7..62da1408b54 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4223,7 +4223,7 @@ def list_artifact_versions( user: Optional[Union[UUID, str]] = None, model: Optional[Union[UUID, str]] = None, pipeline_run: Optional[Union[UUID, str]] = None, - run_metadata: Optional[Dict[str,str]] = None, + run_metadata: Optional[Dict[str, str]] = None, tag: Optional[str] = None, hydrate: bool = False, ) -> Page[ArtifactVersionResponse]: diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 43792314c5c..109e8c213d6 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -187,6 +187,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: Returns: A list of query conditions. + + Raises: + ValueError: the comparison of the column to a numeric value fails. """ if self.operation == GenericFilterOps.CONTAINS: return column.like(f"%{self.value}%") @@ -227,7 +230,8 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: ) except Exception as e: raise ValueError( - f"Failed to compare the column to the numeric value: {e}" + f"Failed to compare the column '{column}' to the " + f"value '{self.value}' (must be numeric): {e}" ) return column == self.value @@ -630,6 +634,10 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: Returns: A tuple of the filter value and the operator. + + Raises: + ValueError: when we try to use the `oneof` operator with the wrong + value. """ operator = GenericFilterOps.EQUALS # Default operator if isinstance(value, str): @@ -645,11 +653,7 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: try: value = json.loads(value) if not isinstance(value, list): - raise ValueError( - "When you are using the 'oneof:' filtering " - "make sure that the provided value is a json " - "formatted list." - ) + raise ValueError except ValueError: raise ValueError( "When you are using the 'oneof:' filtering " diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index b0c2493ecdb..5c091e61fb0 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -14,10 +14,10 @@ from zenml import ( load_artifact, + log_artifact_metadata, log_metadata, pipeline, save_artifact, -log_artifact_metadata, step, ) from zenml.artifacts.utils import register_artifact diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 81355b8cab7..70e7608f7a8 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -107,7 +107,10 @@ def producer() -> Annotated[str, "bar"]: model_name=model.name, model_version=model.version, ) - log_metadata(metadata={"foobar": "artifact_meta_" + model.version}) + log_metadata( + metadata={"foobar": "artifact_meta_" + model.version}, + artifact_name="bar", + ) return "artifact_data_" + model.version diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index f2b85868a89..d520cfd83a4 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -92,7 +92,9 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: - log_metadata(metadata={"some_key": "some_value"}, artifact_name="my_output") + log_metadata( + metadata={"some_key": "some_value"}, artifact_name="my_output" + ) return 42 diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 9d5ea6b5031..dde9c8e8767 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -976,7 +976,7 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: log_metadata( metadata={"some_meta": "meta_new_one"}, model_name=model.name, - model_version=model.model_version, + model_version=model.version, ) mv = client.create_model_version( From 6edc16e75111e82859fdbb9ab09f96c6920bb5d1 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 11 Nov 2024 10:16:52 +0100 Subject: [PATCH 049/124] fixing the test --- tests/integration/functional/test_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index dde9c8e8767..bd9583a8a1d 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -967,7 +967,9 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: """Produce artifact with metadata.""" from zenml.client import Client - log_metadata(metadata={"some_meta": "meta_new_one"}) + log_metadata( + metadata={"some_meta": "meta_new_one"}, artifact_name="new_one" + ) client = Client() From 030d530b4e6e32e27fcc840525be0f0dedf1d120 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 11 Nov 2024 11:06:28 +0100 Subject: [PATCH 050/124] fixing per comments --- src/zenml/utils/metadata_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index deb7700c384..47bd4f06e38 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -122,7 +122,6 @@ def log_metadata( ValueError: If no identifiers are provided and the function is not called from within a step. """ - # Initialize the client client = Client() # If a step name is provided, we need a run_id_name_or_prefix and will log @@ -155,12 +154,9 @@ def log_metadata( # as well. elif step_id is not None: step_model = client.get_run_step(step_run_id=step_id) - run_model = client.get_pipeline_run( - name_id_or_prefix=step_model.pipeline_run_id - ) client.create_run_metadata( metadata=metadata, - resource_id=run_model.id, + resource_id=step_model.pipeline_run_id, resource_type=MetadataResourceTypes.PIPELINE_RUN, ) client.create_run_metadata( From 929fba45f131fc8895eafe7bd9ac1443b2642ae2 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 11 Nov 2024 11:49:38 +0100 Subject: [PATCH 051/124] added validation, constant error message --- src/zenml/models/v2/base/filter.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 109e8c213d6..279d9861aa5 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -64,6 +64,11 @@ AnyQuery = TypeVar("AnyQuery", bound=Any) +ONEOF_ERROR = ( + "When you are using the 'oneof:' filtering make sure that the " + "provided value is a json formatted list." +) + class Filter(BaseModel, ABC): """Filter for all fields. @@ -179,6 +184,13 @@ class StrFilter(Filter): GenericFilterOps.LTE, ] + @model_validator(mode="after") + def check_value_if_operation_oneof(self) -> "StrFilter": + if self.operation == GenericFilterOps.ONEOF: + if not isinstance(self.value, list): + raise ValueError(ONEOF_ERROR) + return self + def generate_query_conditions_from_column(self, column: Any) -> Any: """Generate query conditions for a string column. @@ -655,11 +667,7 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: if not isinstance(value, list): raise ValueError except ValueError: - raise ValueError( - "When you are using the 'oneof:' filtering " - "make sure that the provided value is a json " - "formatted list." - ) + raise ValueError(ONEOF_ERROR) return value, operator @@ -1098,10 +1106,7 @@ def _define_uuid_filter( # For equality checks, ensure that the value is a valid UUID. if operator == GenericFilterOps.ONEOF and not isinstance(value, list): - raise ValueError( - "If you are using `oneof:` as a filtering op, the value needs " - "to be a json formatted list string." - ) + raise ValueError(ONEOF_ERROR) # Generate the filter. uuid_filter = UUIDFilter( From c1bcb003b2839a4b7f6d4c60edecc76659920706 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 12 Nov 2024 09:07:16 +0100 Subject: [PATCH 052/124] linting --- src/zenml/models/v2/base/filter.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 279d9861aa5..1c4d2cccfb5 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -186,6 +186,14 @@ class StrFilter(Filter): @model_validator(mode="after") def check_value_if_operation_oneof(self) -> "StrFilter": + """Validator to check if value is a list if oneof operation is used. + + Raises: + ValueError: If the value is not a list + + Returns: + self + """ if self.operation == GenericFilterOps.ONEOF: if not isinstance(self.value, list): raise ValueError(ONEOF_ERROR) From 57ba4f9a0026e0837e7b752b0b0ba9cd4c2d780e Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 12 Nov 2024 15:02:42 +0100 Subject: [PATCH 053/124] new changes --- src/zenml/artifacts/utils.py | 3 +- src/zenml/client.py | 11 +- src/zenml/model/model.py | 3 +- src/zenml/models/v2/core/artifact_version.py | 9 +- src/zenml/models/v2/core/model_version.py | 9 +- src/zenml/models/v2/core/pipeline_run.py | 9 +- src/zenml/models/v2/core/run_metadata.py | 9 +- src/zenml/models/v2/core/step_run.py | 9 +- src/zenml/orchestrators/publish_utils.py | 6 +- src/zenml/steps/utils.py | 3 +- src/zenml/utils/metadata_utils.py | 138 ++++++++---------- .../routers/workspaces_endpoints.py | 33 ++--- .../cc269488e5a9_separate_run_metadata.py | 68 +++++++++ src/zenml/zen_stores/schemas/__init__.py | 6 +- .../zen_stores/schemas/artifact_schemas.py | 11 +- src/zenml/zen_stores/schemas/model_schemas.py | 11 +- .../schemas/pipeline_run_schemas.py | 12 +- .../schemas/run_metadata_schemas.py | 50 ++++++- .../zen_stores/schemas/step_run_schemas.py | 12 +- src/zenml/zen_stores/sql_zen_store.py | 38 +++-- tests/integration/functional/test_client.py | 11 +- .../functional/zen_stores/test_zen_store.py | 3 +- 22 files changed, 288 insertions(+), 176 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index e18485d42ab..76fa22eb0e1 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -440,8 +440,7 @@ def log_artifact_metadata( response = client.get_artifact_version(artifact_name, artifact_version) client.create_run_metadata( metadata=metadata, - resource_id=response.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[(response.id, MetadataResourceTypes.ARTIFACT_VERSION)], ) else: diff --git a/src/zenml/client.py b/src/zenml/client.py index 154700bce8c..684e3f33ac4 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4435,17 +4435,14 @@ def _delete_artifact_from_artifact_store( def create_run_metadata( self, metadata: Dict[str, "MetadataType"], - resource_id: UUID, - resource_type: MetadataResourceTypes, + resources: List[Tuple[UUID, MetadataResourceTypes]], stack_component_id: Optional[UUID] = None, ) -> None: """Create run metadata. Args: metadata: The metadata to create as a dictionary of key-value pairs. - resource_id: The ID of the resource for which the - metadata was produced. - resource_type: The type of the resource for which the + resources: The ID and type of the resources for that the metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. @@ -4480,14 +4477,12 @@ def create_run_metadata( run_metadata = RunMetadataRequest( workspace=self.active_workspace.id, user=self.active_user.id, - resource_id=resource_id, - resource_type=resource_type, + resources=resources, stack_component_id=stack_component_id, values=values, types=types, ) self.zen_store.create_run_metadata(run_metadata) - return None # -------------------------------- Secrets --------------------------------- diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 05c0045ca66..39db25eb657 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -341,8 +341,7 @@ def log_metadata( response = self._get_or_create_model_version() Client().create_run_metadata( metadata=metadata, - resource_id=response.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources=[(response.id, MetadataResourceTypes.MODEL_VERSION)], ) @property diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index d26b3bceef4..fa328c11c2d 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -569,6 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ModelSchema, ModelVersionArtifactSchema, PipelineRunSchema, + RunMetadataResourceLinkSchema, RunMetadataSchema, StepRunInputArtifactSchema, StepRunOutputArtifactSchema, @@ -656,10 +657,12 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == ArtifactVersionSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceLinkSchema.resource_id + == ArtifactVersionSchema.id, + RunMetadataResourceLinkSchema.resource_type == MetadataResourceTypes.ARTIFACT_VERSION, - RunMetadataSchema.key == key, + RunMetadataResourceLinkSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index f2e3a7aa911..02cb45ed1e5 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -656,6 +656,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, + RunMetadataResourceLinkSchema, RunMetadataSchema, UserSchema, ) @@ -676,10 +677,12 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == ModelVersionSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceLinkSchema.resource_id + == ModelVersionSchema.id, + RunMetadataResourceLinkSchema.resource_type == MetadataResourceTypes.MODEL_VERSION, - RunMetadataSchema.key == key, + RunMetadataResourceLinkSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 8468c105bee..60e653ac843 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -722,6 +722,7 @@ def get_custom_filters( PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataResourceLinkSchema, RunMetadataSchema, ScheduleSchema, StackComponentSchema, @@ -897,10 +898,12 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == PipelineRunSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceLinkSchema.resource_id + == PipelineRunSchema.id, + RunMetadataResourceLinkSchema.resource_type == MetadataResourceTypes.PIPELINE_RUN, - RunMetadataSchema.key == key, + RunMetadataResourceLinkSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index c4a2ef8e678..da395ab0e6c 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple from uuid import UUID from pydantic import Field @@ -30,11 +30,8 @@ class RunMetadataRequest(WorkspaceScopedRequest): """Request model for run metadata.""" - resource_id: UUID = Field( - title="The ID of the resource that this metadata belongs to.", - ) - resource_type: MetadataResourceTypes = Field( - title="The type of the resource that this metadata belongs to.", + resources: List[Tuple[UUID, MetadataResourceTypes]] = Field( + title="The list of resources that this metadata belongs to." ) stack_component_id: Optional[UUID] = Field( title="The ID of the stack component that this metadata belongs to." diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 7052a1b42d7..29557c1a554 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -594,6 +594,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelSchema, ModelVersionSchema, + RunMetadataResourceLinkSchema, RunMetadataSchema, StepRunSchema, ) @@ -612,10 +613,12 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == StepRunSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceLinkSchema.resource_id + == StepRunSchema.id, + RunMetadataResourceLinkSchema.resource_type == MetadataResourceTypes.STEP_RUN, - RunMetadataSchema.key == key, + RunMetadataResourceLinkSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index 1b8f168517a..a6d864aae32 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -129,8 +129,7 @@ def publish_pipeline_run_metadata( for stack_component_id, metadata in pipeline_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resource_id=pipeline_run_id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, + resources=[(pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN)], stack_component_id=stack_component_id, ) @@ -150,7 +149,6 @@ def publish_step_run_metadata( for stack_component_id, metadata in step_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resource_id=step_run_id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[(step_run_id, MetadataResourceTypes.STEP_RUN)], stack_component_id=stack_component_id, ) diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index e237d12f9ff..4cdce1d85e3 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -477,8 +477,7 @@ def log_step_metadata( step_run_id = pipeline_run.steps[step_name].id client.create_run_metadata( metadata=metadata, - resource_id=step_run_id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[(step_run_id, MetadataResourceTypes.STEP_RUN)], ) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 47bd4f06e38..79187941e1f 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -127,67 +127,46 @@ def log_metadata( # If a step name is provided, we need a run_id_name_or_prefix and will log # metadata for the steps pipeline and model accordingly. if step_name is not None and run_id_name_or_prefix is not None: - run_model = client.get_pipeline_run( - name_id_or_prefix=run_id_name_or_prefix - ) - step_model = run_model.steps[step_name] + run = client.get_pipeline_run(run_id_name_or_prefix) + step = run.steps[step_name] - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.id, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - if step_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources = [ + (run.id, MetadataResourceTypes.PIPELINE_RUN), + (step.id, MetadataResourceTypes.STEP_RUN), + ] + if step.model_version: + resources.append( + (step.model_version.id, MetadataResourceTypes.MODEL_VERSION) ) - + client.create_run_metadata(metadata=metadata, resources=resources) # If a step is identified by id, fetch it directly through the client, # follow a similar procedure and log metadata for its pipeline and model # as well. elif step_id is not None: - step_model = client.get_run_step(step_run_id=step_id) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.pipeline_run_id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.id, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - if step_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + step = client.get_run_step(step_id) + + resources = [ + (step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN), + (step.id, MetadataResourceTypes.STEP_RUN), + ] + if step.model_version: + resources.append( + (step.model_version.id, MetadataResourceTypes.MODEL_VERSION) ) + client.create_run_metadata(metadata=metadata, resources=resources) # If a pipeline run id is identified, we need to log metadata to it and its # model as well. elif run_id_name_or_prefix is not None: - run_model = client.get_pipeline_run( - name_id_or_prefix=run_id_name_or_prefix - ) - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - if run_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + run = client.get_pipeline_run(run_id_name_or_prefix) + + resources = [(run.id, MetadataResourceTypes.PIPELINE_RUN)] + + if run.model_version: + resources.append( + (run.model_version.id, MetadataResourceTypes.MODEL_VERSION) ) + client.create_run_metadata(metadata=metadata, resources=resources) # If the user provides a model name and version, we use to model abstraction # to fetch the model version and attach the corresponding metadata to it. @@ -195,22 +174,20 @@ def log_metadata( from zenml import Model mv = Model(name=model_name, version=model_version) + client.create_run_metadata( metadata=metadata, - resource_id=mv.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources=[(mv.id, MetadataResourceTypes.MODEL_VERSION)], ) # If the user provides a model version id, we use the client to fetch it and # attach the metadata to it. elif model_version_id is not None: - model_version_id = client.get_model_version( - model_version_name_or_number_or_id=model_version_id - ).id client.create_run_metadata( metadata=metadata, - resource_id=model_version_id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources=[ + (model_version_id, MetadataResourceTypes.MODEL_VERSION) + ], ) # If the user provides an artifact name, there are three possibilities. If @@ -226,8 +203,12 @@ def log_metadata( ) client.create_run_metadata( metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[ + ( + artifact_version_model.id, + MetadataResourceTypes.ARTIFACT_VERSION, + ) + ], ) else: step_context = None @@ -244,20 +225,22 @@ def log_metadata( ) client.create_run_metadata( metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[ + ( + artifact_version_model.id, + MetadataResourceTypes.ARTIFACT_VERSION, + ) + ], ) # If the user directly provides an artifact_version_id, we use the client to # fetch is and attach the metadata accordingly. elif artifact_version_id is not None: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_version_id, - ) client.create_run_metadata( metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[ + (artifact_version_id, MetadataResourceTypes.ARTIFACT_VERSION) + ], ) # If every additional value is None, that means we are calling it bare bones @@ -287,22 +270,21 @@ def log_metadata( "of the step execution, please provide the required " "identifiers." ) + resources = [ + (step_context.step_run.id, MetadataResourceTypes.STEP_RUN), + (step_context.pipeline_run.id, MetadataResourceTypes.PIPELINE_RUN), + ] + if step_context.model_version: + resources.append( + ( + step_context.model_version.id, + MetadataResourceTypes.MODEL_VERSION, + ) + ) client.create_run_metadata( metadata=metadata, - resource_id=step_context.pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=resources, ) - if step_context.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, - ) else: raise ValueError( diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index dd33ebee869..297d5a37ca4 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1014,24 +1014,21 @@ def create_run_metadata( "is not supported." ) - if run_metadata.resource_type == MetadataResourceTypes.PIPELINE_RUN: - run = zen_store().get_run(run_metadata.resource_id) - verify_permission_for_model(run, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.STEP_RUN: - step = zen_store().get_run_step(run_metadata.resource_id) - verify_permission_for_model(step, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.ARTIFACT_VERSION: - artifact_version = zen_store().get_artifact_version( - run_metadata.resource_id - ) - verify_permission_for_model(artifact_version, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.MODEL_VERSION: - model_version = zen_store().get_model_version(run_metadata.resource_id) - verify_permission_for_model(model_version, action=Action.UPDATE) - else: - raise RuntimeError( - f"Unknown resource type: {run_metadata.resource_type}" - ) + for resource in run_metadata.resources: + if resource[1] == MetadataResourceTypes.PIPELINE_RUN: + run = zen_store().get_run(resource[0]) + verify_permission_for_model(run, action=Action.UPDATE) + elif resource[1] == MetadataResourceTypes.STEP_RUN: + step = zen_store().get_run_step(resource[0]) + verify_permission_for_model(step, action=Action.UPDATE) + elif resource[1] == MetadataResourceTypes.ARTIFACT_VERSION: + artifact_version = zen_store().get_artifact_version(resource[0]) + verify_permission_for_model(artifact_version, action=Action.UPDATE) + elif resource[1] == MetadataResourceTypes.MODEL_VERSION: + model_version = zen_store().get_model_version(resource[0]) + verify_permission_for_model(model_version, action=Action.UPDATE) + else: + raise RuntimeError(f"Unknown resource type: {resource[1]}") verify_permission( resource_type=ResourceType.RUN_METADATA, action=Action.CREATE diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py new file mode 100644 index 00000000000..82050042f16 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -0,0 +1,68 @@ +"""separate run metadata [cc269488e5a9]. + +Revision ID: cc269488e5a9 +Revises: 904464ea4041 +Create Date: 2024-11-12 09:46:46.587478 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "cc269488e5a9" +down_revision = "904464ea4041" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Creates the 'run_metadata_resource_link' table.""" + # Create the `run_metadata_resource_link` table + op.create_table( + "run_metadata_resource_link", + sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("resource_type", sa.String(length=255), nullable=False), + sa.Column( + "run_metadata_id", + sa.Integer, + sa.ForeignKey("run_metadata.id", ondelete="CASCADE"), + nullable=False, + ), + ) + + # Migrate existing data from `run_metadata` to `run_metadata_resource` + connection = op.get_bind() + + # Fetch data from the existing `run_metadata` table + run_metadata_data = connection.execute( + sa.text(""" + SELECT id, resource_id, resource_type + FROM run_metadata + """) + ).fetchall() + + # Insert data into the new `run_metadata_resource` table + for row in run_metadata_data: + # Insert resource data with reference to `run_metadata` + connection.execute( + sa.text(""" + INSERT INTO run_metadata_resource_link (resource_id, resource_type, run_metadata_id) + VALUES (:id, :resource_id, :resource_type, :run_metadata_id) + """), + { + "resource_id": row.resource_id, + "resource_type": row.resource_type, + "run_metadata_id": row.id, + }, + ) + + # Drop the old `resource_id` and `resource_type` columns from `run_metadata` + op.drop_column("run_metadata", "resource_id") + op.drop_column("run_metadata", "resource_type") + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 2faf233723a..5375614e3fd 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -39,7 +39,10 @@ from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema -from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema +from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceLinkSchema, + RunMetadataSchema, +) from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.secret_schemas import SecretSchema from zenml.zen_stores.schemas.service_schemas import ServiceSchema @@ -90,6 +93,7 @@ "PipelineDeploymentSchema", "PipelineRunSchema", "PipelineSchema", + "RunMetadataResourceLinkSchema", "RunMetadataSchema", "ScheduleSchema", "SecretSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 8b08e51b562..f415cfe308d 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -59,7 +59,9 @@ from zenml.zen_stores.schemas.model_schemas import ( ModelVersionArtifactSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceLinkSchema, + ) from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema @@ -242,10 +244,10 @@ class ArtifactVersionSchema(BaseSchema, table=True): workspace: "WorkspaceSchema" = Relationship( back_populates="artifact_versions" ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==ArtifactVersionSchema.id)", cascade="delete", overlaps="run_metadata", ), @@ -376,7 +378,8 @@ def to_model( producer_step_run_id=producer_step_run_id, visualizations=[v.to_model() for v in self.visualizations], run_metadata={ - m.key: json.loads(m.value) for m in self.run_metadata + m.run_metadata.key: json.loads(m.run_metadata.value) + for m in self.run_metadata_links }, ) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 37cec2c5513..ed34311baf5 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -47,7 +47,9 @@ from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema -from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema +from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceLinkSchema, +) from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema from zenml.zen_stores.schemas.user_schemas import UserSchema @@ -303,10 +305,10 @@ class ModelVersionSchema(NamedSchema, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==ModelVersionSchema.id)", cascade="delete", overlaps="run_metadata", ), @@ -404,7 +406,8 @@ def to_model( workspace=self.workspace.to_model(), description=self.description, run_metadata={ - rm.key: json.loads(rm.value) for rm in self.run_metadata + m.run_metadata.key: json.loads(m.run_metadata.value) + for m in self.run_metadata_links }, ) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 6028451acf2..25db79b5200 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -56,7 +56,9 @@ ModelVersionPipelineRunSchema, ModelVersionSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceLinkSchema, + ) from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema @@ -136,10 +138,10 @@ class PipelineRunSchema(NamedSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( back_populates="pipeline_run", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataSchema.resource_id)==PipelineRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==PipelineRunSchema.id)", cascade="delete", overlaps="run_metadata", ), @@ -276,8 +278,8 @@ def to_model( ) run_metadata = { - metadata_schema.key: json.loads(metadata_schema.value) - for metadata_schema in self.run_metadata + m.run_metadata.key: json.loads(m.run_metadata.value) + for m in self.run_metadata_links } if self.deployment is not None: diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 18d203111c7..9b97aa6fd82 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -17,7 +17,7 @@ from uuid import UUID from sqlalchemy import TEXT, VARCHAR, Column -from sqlmodel import Field, Relationship +from sqlmodel import Field, Relationship, SQLModel from zenml.enums import MetadataResourceTypes from zenml.zen_stores.schemas.base_schemas import BaseSchema @@ -38,8 +38,11 @@ class RunMetadataSchema(BaseSchema, table=True): __tablename__ = "run_metadata" - resource_id: UUID - resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) + # Relationship to link to resources + resources: List["RunMetadataResourceLinkSchema"] = Relationship( + back_populates="run_metadata" + ) + pipeline_run: List["PipelineRunSchema"] = Relationship( back_populates="run_metadata", sa_relationship_kwargs=dict( @@ -103,3 +106,44 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str + + +class RunMetadataResourceLinkSchema(SQLModel, table=True): + """Table for linking resources to run metadata entries.""" + + __tablename__ = "run_metadata_resource_link" + + resource_id: UUID + resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) + run_metadata_id: int = Field(foreign_key="run_metadata.id") + + # Relationship back to the base metadata table + run_metadata: RunMetadataSchema = Relationship(back_populates="resources") + + # Relationship to link specific resource types + pipeline_run: List["PipelineRunSchema"] = Relationship( + back_populates="run_metadata_links", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResource.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResource.resource_id)==PipelineRunSchema.id)" + ), + ) + step_run: List["StepRunSchema"] = Relationship( + back_populates="run_metadata_links", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResource.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResource.resource_id)==StepRunSchema.id)" + ), + ) + artifact_version: List["ArtifactVersionSchema"] = Relationship( + back_populates="run_metadata_links", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", + overlaps="run_metadata,pipeline_run,step_run,model_version", + ), + ) + model_version: List["ModelVersionSchema"] = Relationship( + back_populates="run_metadata_links", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", + overlaps="run_metadata,pipeline_run,step_run,artifact_version", + ), + ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 8500db9715d..f74b21614e3 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -56,7 +56,9 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceLinkSchema, + ) class StepRunSchema(NamedSchema, table=True): @@ -139,10 +141,10 @@ class StepRunSchema(NamedSchema, table=True): deployment: Optional["PipelineDeploymentSchema"] = Relationship( back_populates="step_runs" ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( back_populates="step_run", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataSchema.resource_id)==StepRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==StepRunSchema.id)", cascade="delete", overlaps="run_metadata", ), @@ -219,8 +221,8 @@ def to_model( or a step_configuration. """ run_metadata = { - metadata_schema.key: json.loads(metadata_schema.value) - for metadata_schema in self.run_metadata + m.run_metadata.key: json.loads(m.run_metadata.value) + for m in self.run_metadata_links } input_artifacts = { diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f9c314f774b..b6ed05d0dd3 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -325,6 +325,7 @@ PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataResourceLinkSchema, RunMetadataSchema, RunTemplateSchema, ScheduleSchema, @@ -5516,20 +5517,29 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: The created run metadata. """ with Session(self.engine) as session: - for key, value in run_metadata.values.items(): - type_ = run_metadata.types[key] - run_metadata_schema = RunMetadataSchema( - workspace_id=run_metadata.workspace, - user_id=run_metadata.user, - resource_id=run_metadata.resource_id, - resource_type=run_metadata.resource_type.value, - stack_component_id=run_metadata.stack_component_id, - key=key, - value=json.dumps(value), - type=type_, - ) - session.add(run_metadata_schema) - session.commit() + if run_metadata.resources: + for key, value in run_metadata.values.items(): + type_ = run_metadata.types[key] + run_metadata_schema = RunMetadataSchema( + workspace_id=run_metadata.workspace, + user_id=run_metadata.user, + stack_component_id=run_metadata.stack_component_id, + key=key, + value=json.dumps(value), + type=type_, + ) + session.add(run_metadata_schema) + session.commit() + + for resource in run_metadata.resources: + rm_resource_link = RunMetadataResourceLinkSchema( + resource_id=resource[0], + resource_type=resource[1].value, + run_metadata_id=run_metadata_schema.id, + ) + session.add(rm_resource_link) + session.commit() + return None # ----------------------------- Schedules ----------------------------- diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index bd9583a8a1d..6c916326342 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -484,8 +484,7 @@ def test_create_run_metadata_for_pipeline_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, + resources=[(pipeline_run.id, MetadataResourceTypes.PIPELINE_RUN)], ) rm = clean_client_with_run.get_pipeline_run(pipeline_run.id).run_metadata @@ -501,8 +500,7 @@ def test_create_run_metadata_for_step_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[(step_run.id, MetadataResourceTypes.STEP_RUN)], ) rm = clean_client_with_run.get_run_step(step_run.id).run_metadata @@ -518,8 +516,9 @@ def test_create_run_metadata_for_artifact(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=artifact_version.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[ + (artifact_version.id, MetadataResourceTypes.ARTIFACT_VERSION) + ], ) rm = clean_client_with_run.get_artifact_version( diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index ae4a4011108..c894d567c71 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -5500,8 +5500,7 @@ def test_metadata_full_cycle_with_cascade_deletion( RunMetadataRequest( user=client.active_user.id, workspace=client.active_workspace.id, - resource_id=resource.id, - resource_type=type_, + resources=[(resource.id, type_)], values={"foo": "bar"}, types={"foo": MetadataTypeEnum.STRING}, stack_component_id=sc.id From a68f3c29668d61d00cb01a4c7b8f4c60e87a184f Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 12 Nov 2024 17:03:31 +0100 Subject: [PATCH 054/124] second checkpoint --- src/zenml/models/v2/core/artifact_version.py | 8 +- src/zenml/models/v2/core/model_version.py | 8 +- src/zenml/models/v2/core/pipeline_run.py | 8 +- src/zenml/models/v2/core/step_run.py | 9 +- .../cc269488e5a9_separate_run_metadata.py | 94 +++++++++++++++---- src/zenml/zen_stores/schemas/__init__.py | 4 +- .../zen_stores/schemas/artifact_schemas.py | 6 +- src/zenml/zen_stores/schemas/model_schemas.py | 6 +- .../schemas/pipeline_run_schemas.py | 6 +- .../schemas/run_metadata_schemas.py | 50 +++------- .../zen_stores/schemas/step_run_schemas.py | 6 +- src/zenml/zen_stores/sql_zen_store.py | 4 +- 12 files changed, 117 insertions(+), 92 deletions(-) diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index fa328c11c2d..d7b9f9d0849 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -569,7 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ModelSchema, ModelVersionArtifactSchema, PipelineRunSchema, - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, StepRunInputArtifactSchema, StepRunOutputArtifactSchema, @@ -657,11 +657,11 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataResourceLinkSchema.resource_id + RunMetadataResourceSchema.resource_id == ArtifactVersionSchema.id, - RunMetadataResourceLinkSchema.resource_type + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.ARTIFACT_VERSION, - RunMetadataResourceLinkSchema.run_metadata_id + RunMetadataResourceSchema.run_metadata_id == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 02cb45ed1e5..05a15de8026 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -656,7 +656,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, UserSchema, ) @@ -677,11 +677,11 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataResourceLinkSchema.resource_id + RunMetadataResourceSchema.resource_id == ModelVersionSchema.id, - RunMetadataResourceLinkSchema.resource_type + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.MODEL_VERSION, - RunMetadataResourceLinkSchema.run_metadata_id + RunMetadataResourceSchema.run_metadata_id == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 60e653ac843..a9e380780ab 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -722,7 +722,7 @@ def get_custom_filters( PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, ScheduleSchema, StackComponentSchema, @@ -898,11 +898,11 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataResourceLinkSchema.resource_id + RunMetadataResourceSchema.resource_id == PipelineRunSchema.id, - RunMetadataResourceLinkSchema.resource_type + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.PIPELINE_RUN, - RunMetadataResourceLinkSchema.run_metadata_id + RunMetadataResourceSchema.run_metadata_id == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 29557c1a554..bdfc04dce20 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -594,7 +594,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelSchema, ModelVersionSchema, - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, StepRunSchema, ) @@ -613,11 +613,10 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataResourceLinkSchema.resource_id - == StepRunSchema.id, - RunMetadataResourceLinkSchema.resource_type + RunMetadataResourceSchema.resource_id == StepRunSchema.id, + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.STEP_RUN, - RunMetadataResourceLinkSchema.run_metadata_id + RunMetadataResourceSchema.run_metadata_id == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 82050042f16..2c6d69f60e3 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,11 +1,12 @@ -"""separate run metadata [cc269488e5a9]. +"""Separate run metadata into resource link table with new UUIDs. Revision ID: cc269488e5a9 Revises: 904464ea4041 Create Date: 2024-11-12 09:46:46.587478 - """ +import uuid + import sqlalchemy as sa import sqlmodel from alembic import op @@ -18,10 +19,16 @@ def upgrade() -> None: - """Creates the 'run_metadata_resource_link' table.""" - # Create the `run_metadata_resource_link` table + """Creates the 'run_metadata_resource' table and migrates data.""" + # Create the `run_metadata_resource` table op.create_table( - "run_metadata_resource_link", + "run_metadata_resource", + sa.Column( + "id", + sqlmodel.sql.sqltypes.GUID(), + nullable=False, + primary_key=True, + ), sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), sa.Column("resource_type", sa.String(length=255), nullable=False), sa.Column( @@ -35,7 +42,7 @@ def upgrade() -> None: # Migrate existing data from `run_metadata` to `run_metadata_resource` connection = op.get_bind() - # Fetch data from the existing `run_metadata` table + # Fetch existing `run_metadata` data run_metadata_data = connection.execute( sa.text(""" SELECT id, resource_id, resource_type @@ -43,26 +50,73 @@ def upgrade() -> None: """) ).fetchall() - # Insert data into the new `run_metadata_resource` table - for row in run_metadata_data: - # Insert resource data with reference to `run_metadata` + # Prepare data with new UUIDs for bulk insert + resource_data = [ + { + "id": str(uuid.uuid4()), # Generate a new UUID for each row + "resource_id": row.resource_id, + "resource_type": row.resource_type, + "run_metadata_id": row.id, + } + for row in run_metadata_data + ] + + # Perform bulk insert into `run_metadata_resource` + op.bulk_insert( + sa.table( + "run_metadata_resource", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("resource_type", sa.String(length=255), nullable=False), + sa.Column("run_metadata_id", sa.Integer, nullable=False), + ), + resource_data, + ) + + # Drop the old `resource_id` and `resource_type` columns from `run_metadata` + op.drop_column("run_metadata", "resource_id") + op.drop_column("run_metadata", "resource_type") + + +def downgrade() -> None: + """Reverts the 'run_metadata_resource' table and migrates data back.""" + # Recreate the `resource_id` and `resource_type` columns in `run_metadata` + op.add_column( + "run_metadata", + sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + ) + op.add_column( + "run_metadata", + sa.Column("resource_type", sa.String(length=255), nullable=True), + ) + + # Migrate data back from `run_metadata_resource` to `run_metadata` + connection = op.get_bind() + + # Fetch data from `run_metadata_resource` + run_metadata_resource_data = connection.execute( + sa.text(""" + SELECT resource_id, resource_type, run_metadata_id + FROM run_metadata_resource + """) + ).fetchall() + + # Update `run_metadata` with the data from `run_metadata_resource` + for row in run_metadata_resource_data: connection.execute( sa.text(""" - INSERT INTO run_metadata_resource_link (resource_id, resource_type, run_metadata_id) - VALUES (:id, :resource_id, :resource_type, :run_metadata_id) + UPDATE run_metadata + SET resource_id = :resource_id, resource_type = :resource_type + WHERE id = :run_metadata_id """), { "resource_id": row.resource_id, "resource_type": row.resource_type, - "run_metadata_id": row.id, + "run_metadata_id": row.run_metadata_id, }, ) - # Drop the old `resource_id` and `resource_type` columns from `run_metadata` - op.drop_column("run_metadata", "resource_id") - op.drop_column("run_metadata", "resource_type") - - -def downgrade() -> None: - """Downgrade database schema and/or data back to the previous revision.""" - pass + # Drop the `run_metadata_resource` table + op.drop_table("run_metadata_resource") diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 5375614e3fd..dadb2b747f6 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -40,7 +40,7 @@ from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, ) from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema @@ -93,7 +93,7 @@ "PipelineDeploymentSchema", "PipelineRunSchema", "PipelineSchema", - "RunMetadataResourceLinkSchema", + "RunMetadataResourceSchema", "RunMetadataSchema", "ScheduleSchema", "SecretSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index f415cfe308d..70cf0ab0df4 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -60,7 +60,7 @@ ModelVersionArtifactSchema, ) from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, ) from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema @@ -244,10 +244,10 @@ class ArtifactVersionSchema(BaseSchema, table=True): workspace: "WorkspaceSchema" = Relationship( back_populates="artifact_versions" ) - run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==ArtifactVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", cascade="delete", overlaps="run_metadata", ), diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index ed34311baf5..e0c103c78a9 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -48,7 +48,7 @@ from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, ) from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema @@ -305,10 +305,10 @@ class ModelVersionSchema(NamedSchema, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==ModelVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", cascade="delete", overlaps="run_metadata", ), diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 25db79b5200..052c3f332d0 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -57,7 +57,7 @@ ModelVersionSchema, ) from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, ) from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -138,10 +138,10 @@ class PipelineRunSchema(NamedSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceSchema"] = Relationship( back_populates="pipeline_run", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==PipelineRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", cascade="delete", overlaps="run_metadata", ), diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 9b97aa6fd82..89acf4f1449 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -14,10 +14,10 @@ """SQLModel implementation of pipeline run metadata tables.""" from typing import TYPE_CHECKING, List, Optional -from uuid import UUID +from uuid import UUID, uuid4 from sqlalchemy import TEXT, VARCHAR, Column -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, Relationship from zenml.enums import MetadataResourceTypes from zenml.zen_stores.schemas.base_schemas import BaseSchema @@ -39,38 +39,9 @@ class RunMetadataSchema(BaseSchema, table=True): __tablename__ = "run_metadata" # Relationship to link to resources - resources: List["RunMetadataResourceLinkSchema"] = Relationship( + resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="run_metadata" ) - - pipeline_run: List["PipelineRunSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataSchema.resource_id)==PipelineRunSchema.id)", - overlaps="run_metadata,step_run,artifact_version,model_version", - ), - ) - step_run: List["StepRunSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataSchema.resource_id)==StepRunSchema.id)", - overlaps="run_metadata,pipeline_run,artifact_version,model_version", - ), - ) - artifact_version: List["ArtifactVersionSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="run_metadata,pipeline_run,step_run,model_version", - ), - ) - model_version: List["ModelVersionSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", - overlaps="run_metadata,pipeline_run,step_run,artifact_version", - ), - ) stack_component_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, @@ -108,14 +79,15 @@ class RunMetadataSchema(BaseSchema, table=True): type: str -class RunMetadataResourceLinkSchema(SQLModel, table=True): +class RunMetadataResourceSchema(BaseSchema, table=True): """Table for linking resources to run metadata entries.""" - __tablename__ = "run_metadata_resource_link" + __tablename__ = "run_metadata_resource" + id: UUID = Field(default_factory=uuid4, primary_key=True) resource_id: UUID resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) - run_metadata_id: int = Field(foreign_key="run_metadata.id") + run_metadata_id: UUID = Field(foreign_key="run_metadata.id") # Relationship back to the base metadata table run_metadata: RunMetadataSchema = Relationship(back_populates="resources") @@ -124,26 +96,26 @@ class RunMetadataResourceLinkSchema(SQLModel, table=True): pipeline_run: List["PipelineRunSchema"] = Relationship( back_populates="run_metadata_links", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResource.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResource.resource_id)==PipelineRunSchema.id)" + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)" ), ) step_run: List["StepRunSchema"] = Relationship( back_populates="run_metadata_links", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResource.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResource.resource_id)==StepRunSchema.id)" + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)" ), ) artifact_version: List["ArtifactVersionSchema"] = Relationship( back_populates="run_metadata_links", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", overlaps="run_metadata,pipeline_run,step_run,model_version", ), ) model_version: List["ModelVersionSchema"] = Relationship( back_populates="run_metadata_links", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", overlaps="run_metadata,pipeline_run,step_run,artifact_version", ), ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index f74b21614e3..c8b1ded8789 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -57,7 +57,7 @@ from zenml.zen_stores.schemas.logs_schemas import LogsSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, ) @@ -141,10 +141,10 @@ class StepRunSchema(NamedSchema, table=True): deployment: Optional["PipelineDeploymentSchema"] = Relationship( back_populates="step_runs" ) - run_metadata_links: List["RunMetadataResourceLinkSchema"] = Relationship( + run_metadata_links: List["RunMetadataResourceSchema"] = Relationship( back_populates="step_run", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceLinkSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceLinkSchema.resource_id)==StepRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", cascade="delete", overlaps="run_metadata", ), diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index b6ed05d0dd3..1d7ad2603cc 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -325,7 +325,7 @@ PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, - RunMetadataResourceLinkSchema, + RunMetadataResourceSchema, RunMetadataSchema, RunTemplateSchema, ScheduleSchema, @@ -5532,7 +5532,7 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: session.commit() for resource in run_metadata.resources: - rm_resource_link = RunMetadataResourceLinkSchema( + rm_resource_link = RunMetadataResourceSchema( resource_id=resource[0], resource_type=resource[1].value, run_metadata_id=run_metadata_schema.id, From 07ece0ca3579cd32fdd6b026094e7465789d0bfd Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 18:40:50 +0100 Subject: [PATCH 055/124] fixing revisions --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 2c6d69f60e3..6af578a7fef 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,7 +1,7 @@ """Separate run metadata into resource link table with new UUIDs. Revision ID: cc269488e5a9 -Revises: 904464ea4041 +Revises: 0.70.0 Create Date: 2024-11-12 09:46:46.587478 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision = "cc269488e5a9" -down_revision = "904464ea4041" +down_revision = "0.70.0" branch_labels = None depends_on = None From 4b6f84ade894b6bae1ee8e01302d81845b708d22 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 18:44:12 +0100 Subject: [PATCH 056/124] adding overlap to remove warnings --- src/zenml/zen_stores/schemas/run_metadata_schemas.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index feb75119f70..b3c3d713996 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -104,13 +104,15 @@ class RunMetadataResourceSchema(SQLModel, table=True): pipeline_run: List["PipelineRunSchema"] = Relationship( back_populates="run_metadata", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)" - ), + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", + overlaps="run_metadata,step_run,artifact_version,model_version", + ) ) step_run: List["StepRunSchema"] = Relationship( back_populates="run_metadata", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)" + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", + overlaps="run_metadata,pipeline_run,artifact_version,model_version", ), ) artifact_version: List["ArtifactVersionSchema"] = Relationship( From dca59131a723512ec175511ff0e626164b65a2a3 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 19:58:58 +0100 Subject: [PATCH 057/124] complete docs changes --- .../track-metrics-metadata/README.md | 42 +++++++- .../attach-metadata-to-a-model.md | 62 ++++++++--- .../attach-metadata-to-a-run.md | 85 +++++++++++++++ .../attach-metadata-to-a-step.md | 100 ++++++++++++++++++ .../attach-metadata-to-an-artifact.md | 82 ++++++++++---- .../attach-metadata-to-steps.md | 65 ------------ .../fetch-metadata-within-steps.md | 6 +- .../grouping-metadata.md | 18 +++- .../logging-metadata.md | 16 +-- .../build-pipelines/README.md | 2 +- docs/book/toc.md | 5 +- 11 files changed, 361 insertions(+), 122 deletions(-) create mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md create mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md delete mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md index df281351c70..fd27d792107 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md @@ -5,10 +5,44 @@ description: Tracking metrics and metadata # Track metrics and metadata -Logging metrics and metadata is standardized in ZenML. The most common pattern is to use the `log_xxx` methods, e.g.: +ZenML provides a unified way to log and manage metrics and metadata through +the `log_metadata` function. This versatile function allows you to log +metadata across various entities like models, artifacts, steps, and runs +through a single interface. Additionally, you can adjust if you want to +automatically the same metadata for the related entities. -* Log metadata to a [model](attach-metadata-to-a-model.md): `log_model_metadata` -* Log metadata to an [artifact](attach-metadata-to-an-artifact.md): `log_artifact_metadata` -* Log metadata to a [step](attach-metadata-to-steps.md): `log_step_metadata` +### The most basic use-case + +You can use the `log_metadata` function within a step: + +```python +from zenml import step, log_metadata + +@step +def my_step() -> ...: + log_metadata(metadata={"accuracy": 0.91}) + ... +``` + +This will log the `accuracy` for the step, its pipeline run, and if provided +its model version. + +### Additional use-cases + +The `log_metadata` function also supports various use-cases by allowing you to +specify the target entity (e.g., model, artifact, step, or run) with flexible +parameters. You can learn more about these use-cases in the following pages: + +- [Log metadata to a step](attach-metadata-to-a-step.md) +- [Log metadata to a run](attach-metadata-to-a-run.md) +- [Log metadata to an artifact](attach-metadata-to-an-artifact.md) +- [Log metadata to a model](attach-metadata-to-a-model.md) + +{% hint style="warning" %} +The older methods for logging metadata to specific entities, such as +`log_model_metadata`, `log_artifact_metadata`, and `log_step_metadata`, are +now deprecated. It is recommended to use `log_metadata` for all future +implementations. +{% endhint %}
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index 0a76bf1bc70..d6a3717daea 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -1,34 +1,46 @@ --- -description: >- - Attach any metadata as key-value pairs to your models for future reference and - auditability. +description: Learn how to attach metadata to a model. --- # Attach metadata to a model +ZenML allows you to log metadata for models, which provides additional context +that goes beyond individual artifact details. Model metadata can represent +high-level insights, such as evaluation results, deployment information, +or customer-specific details, making it easier to manage and interpret +the model's usage and performance across different versions. + ## Logging Metadata for Models -While artifact metadata is specific to individual outputs of steps, model metadata encapsulates broader and more general information that spans across multiple artifacts. For example, evaluation results or the name of a customer for whom the model is intended could be logged with the model. +To log metadata for a model, use the `log_metadata` function. This function +lets you attach key-value metadata to a model, which can include metrics and +other JSON-serializable values, such as custom ZenML types like `Uri`, +`Path`, and `StorageSize`. Here's an example of logging metadata for a model: ```python -from zenml import step, log_model_metadata, ArtifactConfig, get_step_context from typing import Annotated + import pandas as pd -from sklearn.ensemble import RandomForestClassifier from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig + @step -def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier", is_model_artifact=True)]: - """Train a model""" - # Fit the model and compute metrics +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, ArtifactConfig( + name="sklearn_classifier", is_model_artifact=True + ) +]: + """Train a model and log model metadata.""" classifier = RandomForestClassifier().fit(dataset) accuracy, precision, recall = ... # Log metadata for the model - # This associates the metadata with the ZenML model, not the artifact - log_model_metadata( + log_metadata( metadata={ "evaluation_metrics": { "accuracy": accuracy, @@ -36,19 +48,35 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactCon "recall": recall } }, - # Omitted model_name will use the model in the current context model_name="zenml_model_name", - # Omitted model_version will default to 'latest' - model_version="zenml_model_version", + model_version="zenml_model_version" ) return classifier ``` -In this example, the metadata is associated with the model rather than the specific classifier artifact. This is particularly useful when the metadata reflects an aggregation or summary of various steps and artifacts in the pipeline. +In this example, the metadata is associated with the model rather than the +specific classifier artifact. This is particularly useful when the metadata +reflects an aggregation or summary of various steps and artifacts in the +pipeline. + +### Selecting Models with `log_metadata` + +When using `log_metadata` with a model, ZenML provides flexible options to +attach metadata accurately: + +1. **Model Name and Version Provided**: If both a model name and version are + provided, ZenML will use these to identify and attach metadata to the + specific model version. +2. **Model Name Only**: If only a model name is provided, ZenML will attach + metadata to the latest version of the model. +3. **Model Version ID Provided**: If a model version ID is directly provided, + ZenML will use it to fetch and attach the metadata to that specific model + version. ## Fetching logged metadata -Once metadata has been logged in an [artifact](attach-metadata-to-an-artifact.md), model, or [step](attach-metadata-to-steps.md), we can easily fetch the metadata with the ZenML Client: +Once metadata has been attached to a model, it can be retrieved for inspection +or analysis using the ZenML Client. ```python from zenml.client import Client @@ -56,7 +84,7 @@ from zenml.client import Client client = Client() model = client.get_model_version("my_model", "my_version") -print(model.run_metadata["metadata_key"].value) +print(model.run_metadata["metadata_key"]) ```
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md new file mode 100644 index 00000000000..525b379e5b0 --- /dev/null +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -0,0 +1,85 @@ +--- +description: Attaching metadata to a run. +--- + +# Attach Metadata to a Run + +In ZenML, you can log metadata directly to a pipeline run, either during or +after execution, using the `log_metadata` function. This function allows you +to attach a dictionary of key-value pairs as metadata to a pipeline run, +with values that can be any JSON-serializable data type, including ZenML +custom types like `Uri`, `Path`, `DType`, and `StorageSize`. + +## Logging Metadata Within a Run + +If you are logging metadata from within a step that’s part of a pipeline run, +calling `log_metadata` will attach the specified metadata to the current +pipeline run. This is especially useful for logging details about the run +while it's still active. + +```python +from typing import Annotated + +import pandas as pd +from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig + + +@step +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, + ArtifactConfig(name="sklearn_classifier", is_model_artifact=True) +]: + """Train a model and log run-level metadata.""" + classifier = RandomForestClassifier().fit(dataset) + accuracy, precision, recall = ... + + # Log metadata at the run level + log_metadata( + metadata={ + "run_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + } + ) + return classifier +``` + +{% hint style="warning" %} +In order to log metadata to a pipeline run during the step execution without +specifying any additional identifiers, `log_related_entities` should be +`True` (default behaviour). +{% endhint %} + +## Logging Metadata Outside a Run + +You can also attach metadata to a specific pipeline run after its execution, +using identifiers like the run ID. This is useful when logging information or +metrics that were calculated post-execution. + +```python +from zenml import log_metadata + +log_metadata( + metadata={"post_run_info": {"some_metric": 5.0}}, + run_id_name_or_prefix="run_id_name_or_prefix" +) +``` + +## Fetching Logged Metadata + +Once metadata has been logged in a pipeline run, you can retrieve it using +the ZenML Client: + +```python +from zenml.client import Client + +client = Client() +run = client.get_pipeline_run("run_id_name_or_prefix") + +print(run.run_metadata["metadata_key"]) +``` \ No newline at end of file diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md new file mode 100644 index 00000000000..e87cc4d99f3 --- /dev/null +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md @@ -0,0 +1,100 @@ +--- +description: Learn how to attach metadata to a step. +--- + +# Attach metadata to a step + +In ZenML, you can log metadata for a specific step during or after its +execution by using the `log_metadata` function. This function allows you to +attach a dictionary of key-value pairs as metadata to a step. The metadata +can be any JSON-serializable value, including custom classes such as +`Uri`, `Path`, `DType`, and `StorageSize`. + +## Logging Metadata Within a Step + +If called within a step, `log_metadata` automatically attaches the metadata to +the currently executing step and its associated pipeline run. This is +ideal for logging metrics or information that becomes available during the +step execution. + +```python +from typing import Annotated + +import pandas as pd +from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig + + +@step +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, + ArtifactConfig(name="sklearn_classifier", is_model_artifact=True) +]: + """Train a model and log evaluation metrics.""" + classifier = RandomForestClassifier().fit(dataset) + accuracy, precision, recall = ... + + # Log metadata at the step level + log_metadata( + metadata={ + "evaluation_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + } + ) + return classifier +``` + +{% hint style="info" %} +If you do not want to log the same metadata for the related entries such as +the pipeline run and the model version, you can set the `log_related_entities` +to `False` when you call `log_metadata`. +{% endhint %} + + +## Logging Metadata Outside a Step + +You can also log metadata for a specific step after execution, using +identifiers to specify the pipeline, step, and run. This approach is +useful when you want to log metadata post-execution. + +```python +from zenml import log_metadata + +log_metadata( + metadata={ + "additional_info": {"a_number": 3} + }, + step_name="step_name", + run_id_name_or_prefix="run_id_name_or_prefix" +) + +# or + +log_metadata( + metadata={ + "additional_info": {"a_number": 3} + }, + step_id="step_id", +) +``` + +## Fetching logged metadata + +Once metadata has been logged in a step, we can easily fetch the metadata with +the ZenML Client: + +```python +from zenml.client import Client + +client = Client() +step = client.get_pipeline_run().steps["step_name"] + +print(step.run_metadata["metadata_key"]) +``` + +
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md index 01e068e7b3f..7fffe70a4a4 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md @@ -1,48 +1,82 @@ --- -description: Learn how to log metadata for artifacts and models in ZenML. +description: Learn how to attach metadata to an artifact. --- # Attach metadata to an artifact -![Metadata in the dashboard](../../.gitbook/assets/metadata-in-dashboard.png) +![Metadata in the dashboard](../../../.gitbook/assets/metadata-in-dashboard.png) -Metadata plays a critical role in ZenML, providing context and additional information about various entities within the platform. Anything which is `metadata` in ZenML can be compared in the dashboard. - -This guide will explain how to log metadata for artifacts and models in ZenML and detail the types of metadata that can be logged. +In ZenML, metadata enhances artifacts by adding context and important details, +such as size, structure, or performance metrics. This metadata is accessible +in the ZenML dashboard, making it easier to inspect, compare, and track +artifacts across pipeline runs. ## Logging Metadata for Artifacts -Artifacts in ZenML are outputs of steps within a pipeline, such as datasets, models, or evaluation results. Associating metadata with artifacts can help users understand the nature and characteristics of these outputs. +Artifacts in ZenML are outputs of steps within a pipeline, such as datasets, +models, or evaluation results. Associating metadata with artifacts can help +users understand the nature and characteristics of these outputs. -To log metadata for an artifact, you can use the `log_artifact_metadata` method. This method allows you to attach a dictionary of key-value pairs as metadata to an artifact. The metadata can be any JSON-serializable value, including custom classes such as `Uri`, `Path`, `DType`, and `StorageSize`. Find out more about these different types [here](../track-metrics-metadata/logging-metadata.md). +To log metadata for an artifact, use the `log_metadata` function, specifying +the artifact name, version, or ID. The metadata can be any JSON-serializable +value, including ZenML custom types like `Uri`, `Path`, `DType`, and +`StorageSize`. Here's an example of logging metadata for an artifact: ```python -from zenml import step, log_artifact_metadata +from typing import Annotated + +import pandas as pd + +from zenml import step, log_metadata from zenml.metadata.metadata_types import StorageSize + @step -def process_data_step(dataframe: pd.DataFrame) -> Annotated[pd.DataFrame, "processed_data"],: +def process_data_step(dataframe: pd.DataFrame) -> Annotated[ + pd.DataFrame, "processed_data" +]: """Process a dataframe and log metadata about the result.""" - # Perform processing on the dataframe... processed_dataframe = ... # Log metadata about the processed dataframe - log_artifact_metadata( + log_metadata( artifact_name="processed_data", metadata={ "row_count": len(processed_dataframe), "columns": list(processed_dataframe.columns), - "storage_size": StorageSize(processed_dataframe.memory_usage().sum()) + "storage_size": StorageSize( + processed_dataframe.memory_usage().sum()) } ) return processed_dataframe ``` +### Selecting the artifact to log the metadata to + +When using `log_metadata` with an artifact name, ZenML provides flexible +options to attach metadata to the correct artifact: + +1. **Name and Version Provided**: If both an artifact name and version are +provided, ZenML will use these to identify and attach metadata to the +specific artifact version. +2. **Name Only, Within a Step**: If only a name is provided and +`log_metadata` is called within a step, ZenML will try to locate the +corresponding output artifact within the step and attach the metadata to it. If +an output with the provided name does not exist in the step, check scenario 3. +3. **Name Only, Outside a Step**: If only a name is provided and +`log_metadata` is called outside a step, ZenML will attach metadata to the +latest version of the artifact. +4. **Artifact Version ID Provided**: If an artifact version ID is provided +directly, ZenML will use it to fetch and attach the metadata to that +specific artifact version. + ## Fetching logged metadata -Once metadata has been logged in an artifact, or [step](../track-metrics-metadata/attach-metadata-to-a-model.md), we can easily fetch the metadata with the ZenML Client: +Once metadata has been logged in an artifact, or +[step](../track-metrics-metadata/attach-metadata-to-a-model.md), we can easily +fetch the metadata with the ZenML Client: ```python from zenml.client import Client @@ -50,19 +84,24 @@ from zenml.client import Client client = Client() artifact = client.get_artifact_version("my_artifact", "my_version") -print(artifact.run_metadata["metadata_key"].value) +print(artifact.run_metadata["metadata_key"]) ``` ## Grouping Metadata in the Dashboard -When logging metadata passing a dictionary of dictionaries in the `metadata` parameter will group the metadata into cards in the ZenML dashboard. This feature helps organize metadata into logical sections, making it easier to visualize and understand. +When logging metadata passing a dictionary of dictionaries in the `metadata` +parameter will group the metadata into cards in the ZenML dashboard. This +feature helps organize metadata into logical sections, making it easier to +visualize and understand. Here's an example of grouping metadata into cards: ```python +from zenml import log_metadata + from zenml.metadata.metadata_types import StorageSize -log_artifact_metadata( +log_metadata( metadata={ "model_metrics": { "accuracy": 0.95, @@ -73,12 +112,13 @@ log_artifact_metadata( "dataset_size": StorageSize(1500000), "feature_columns": ["age", "income", "score"] } - } + }, + artifact_name="my_artifact", ) ``` -In the ZenML dashboard, "model\_metrics" and "data\_details" would appear as separate cards, each containing their respective key-value pairs. - -
ZenML Scarf
- +In the ZenML dashboard, "model_metrics" and "data_details" would appear as +separate cards, each containing their respective key-value pairs. + +
ZenML Scarf
\ No newline at end of file diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md deleted file mode 100644 index ee2573720ae..00000000000 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md +++ /dev/null @@ -1,65 +0,0 @@ -# Attach metadata to steps - -You might want to log metadata and have that be attached to a specific step during the course of your work. This is possible by using the `log_step_metadata` method. This method allows you to attach a dictionary of key-value pairs as metadata to a step. The metadata can be any JSON-serializable value, including custom classes such as `Uri`, `Path`, `DType`, and `StorageSize`. - -You can call this method from within a step or from outside. If you call it from within it will attach the metadata to the step and run that is currently being executed. - -```python -from zenml import step, log_step_metadata, ArtifactConfig, get_step_context -from typing import Annotated -import pandas as pd -from sklearn.ensemble import RandomForestClassifier -from sklearn.base import ClassifierMixin - -@step -def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier", is_model_artifact=True)]: - """Train a model""" - # Fit the model and compute metrics - classifier = RandomForestClassifier().fit(dataset) - accuracy, precision, recall = ... - - # Log metadata at the step level - # This associates the metadata with the ZenML step run - log_step_metadata( - metadata={ - "evaluation_metrics": { - "accuracy": accuracy, - "precision": precision, - "recall": recall - } - }, - ) - return classifier -``` - -If you call it from outside you can attach the metadata to a specific step run from any pipeline and step. This is useful if you want to attach the metadata after you've run the step. - -```python -from zenml import log_step_metadata -# run some step - -# subsequently log the metadata for the step -log_step_metadata( - metadata={ - "some_metadata": {"a_number": 3} - }, - pipeline_name_id_or_prefix="my_pipeline", - step_name="my_step", - run_id="my_step_run_id" -) -``` - -## Fetching logged metadata - -Once metadata has been logged in an [artifact](attach-metadata-to-an-artifact.md), [model](attach-metadata-to-a-model.md), we can easily fetch the metadata with the ZenML Client: - -```python -from zenml.client import Client - -client = Client() -step = client.get_pipeline_run().steps["step_name"] - -print(step.run_metadata["metadata_key"].value) -``` - -
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md index 25ee26f2095..d57f523483d 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md @@ -11,6 +11,7 @@ To find information about the pipeline or step that is currently running, you ca ```python from zenml import step, get_step_context + @step def my_step(): step_context = get_step_context() @@ -19,9 +20,12 @@ def my_step(): step_name = step_context.step_run.name ``` -Furthermore, you can also use the `StepContext` to find out where the outputs of your current step will be stored and which [Materializer](../handle-data-artifacts/handle-custom-data-types.md) class will be used to save them: +Furthermore, you can also use the `StepContext` to find out where the outputs of your current step will be stored and which [Materializer](../../data-artifact-management/handle-data-artifacts/handle-custom-data-types.md) class will be used to save them: ```python +from zenml import step, get_step_context + + @step def my_step(): step_context = get_step_context() diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md index 52838875085..e90400f96c5 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md @@ -4,16 +4,20 @@ description: Learn how to group key-value pairs in the dashboard. # Grouping Metadata in the Dashboard -![Metadata in the dashboard](../../.gitbook/assets/metadata-in-dashboard.png) +![Metadata in the dashboard](../../../.gitbook/assets/metadata-in-dashboard.png) -When logging metadata passing a dictionary of dictionaries in the `metadata` parameter will group the metadata into cards in the ZenML dashboard. This feature helps organize metadata into logical sections, making it easier to visualize and understand. +When logging metadata passing a dictionary of dictionaries in the +`metadata` parameter will group the metadata into cards in the ZenML dashboard. +This feature helps organize metadata into logical sections, making it +easier to visualize and understand. Here's an example of grouping metadata into cards: ```python +from zenml import log_metadata from zenml.metadata.metadata_types import StorageSize -log_artifact_metadata( +log_metadata( metadata={ "model_metrics": { "accuracy": 0.95, @@ -24,11 +28,15 @@ log_artifact_metadata( "dataset_size": StorageSize(1500000), "feature_columns": ["age", "income", "score"] } - } + }, + artifact_name="my_artifact", + artifact_version="my_artifact_version", ) ``` -In the ZenML dashboard, "model\_metrics" and "data\_details" would appear as separate cards, each containing their respective key-value pairs. +In the ZenML dashboard, "model_metrics" and "data_details" would appear +as separate cards, each containing their respective key-value pairs. +
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md index f7d4c67c199..63501056a3e 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md @@ -4,13 +4,15 @@ description: Tracking your metadata. # Special Metadata Types -ZenML supports several special metadata types to capture specific kinds of information. Here are examples of how to use the special types `Uri`, `Path`, `DType`, and `StorageSize`: +ZenML supports several special metadata types to capture specific kinds of +information. Here are examples of how to use the special types `Uri`, `Path`, +`DType`, and `StorageSize`: ```python -from zenml.metadata.metadata_types import StorageSize, DType -from zenml import log_artifact_metadata +from zenml import log_metadata +from zenml.metadata.metadata_types import StorageSize, DType, Uri, Path -log_artifact_metadata( +log_metadata( metadata={ "dataset_source": Uri("gs://my-bucket/datasets/source.csv"), "preprocessing_script": Path("/scripts/preprocess.py"), @@ -20,7 +22,8 @@ log_artifact_metadata( "score": DType("int") }, "processed_data_size": StorageSize(2500000) - } + }, + artifact_name="my_artifact", ) ``` @@ -31,6 +34,7 @@ In this example: * `DType` is used to describe the data types of specific columns. * `StorageSize` is used to indicate the size of the processed data in bytes. -These special types help standardize the format of metadata and ensure that it is logged in a consistent and interpretable manner. +These special types help standardize the format of metadata and ensure that it +is logged in a consistent and interpretable manner.
ZenML Scarf
diff --git a/docs/book/how-to/pipeline-development/build-pipelines/README.md b/docs/book/how-to/pipeline-development/build-pipelines/README.md index 28985f09f7f..7ea75d73ba4 100644 --- a/docs/book/how-to/pipeline-development/build-pipelines/README.md +++ b/docs/book/how-to/pipeline-development/build-pipelines/README.md @@ -46,6 +46,6 @@ locally or remotely. See our documentation on this [here](../../getting-started/ Check below for more advanced ways to build and interact with your pipeline. -
Configure pipeline/step parametersuse-pipeline-step-parameters.md
Name and annotate step outputsstep-output-typing-and-annotation.md
Control caching behaviorcontrol-caching-behavior.md
Run pipeline from a pipelinetrigger-a-pipeline-from-another.md
Control the execution order of stepscontrol-execution-order-of-steps.md
Customize the step invocation idsusing-a-custom-step-invocation-id.md
Name your pipeline runsname-your-pipeline-and-runs.md
Use failure/success hooksuse-failure-success-hooks.md
Hyperparameter tuninghyper-parameter-tuning.md
Attach metadata to stepsattach-metadata-to-steps.md
Fetch metadata within stepsfetch-metadata-within-steps.md
Fetch metadata during pipeline compositionfetch-metadata-within-pipeline.md
Enable or disable logs storingenable-or-disable-logs-storing.md
Special Metadata Typeslogging-metadata.md
Access secrets in a stepaccess-secrets-in-a-step.md
+
Configure pipeline/step parametersuse-pipeline-step-parameters.md
Name and annotate step outputsstep-output-typing-and-annotation.md
Control caching behaviorcontrol-caching-behavior.md
Run pipeline from a pipelinetrigger-a-pipeline-from-another.md
Control the execution order of stepscontrol-execution-order-of-steps.md
Customize the step invocation idsusing-a-custom-step-invocation-id.md
Name your pipeline runsname-your-pipeline-and-runs.md
Use failure/success hooksuse-failure-success-hooks.md
Hyperparameter tuninghyper-parameter-tuning.md
Attach metadata to a stepattach-metadata-to-a-step.md
Fetch metadata within stepsfetch-metadata-within-steps.md
Fetch metadata during pipeline compositionfetch-metadata-within-pipeline.md
Enable or disable logs storingenable-or-disable-logs-storing.md
Special Metadata Typeslogging-metadata.md
Access secrets in a stepaccess-secrets-in-a-step.md
ZenML Scarf
diff --git a/docs/book/toc.md b/docs/book/toc.md index 7ee4aea6e65..e8c61c33932 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -149,9 +149,10 @@ * [Linking model binaries/data to a Model](how-to/model-management-metrics/model-control-plane/linking-model-binaries-data-to-models.md) * [Load artifacts from Model](how-to/model-management-metrics/model-control-plane/load-artifacts-from-model.md) * [Track metrics and metadata](how-to/model-management-metrics/track-metrics-metadata/README.md) - * [Attach metadata to a model](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) + * [Attach metadata to a step](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md) + * [Attach metadata to a run](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) * [Attach metadata to an artifact](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md) - * [Attach metadata to steps](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md) + * [Attach metadata to a model](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) * [Group metadata](how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md) * [Special Metadata Types](how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md) * [Fetch metadata within steps](how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md) From b767269d564e37fd42f3a992a4b48c87d1884643 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 19:59:32 +0100 Subject: [PATCH 058/124] adding a parameter to control the related entity behaviour --- src/zenml/utils/metadata_utils.py | 57 +++++++++++++------ .../schemas/run_metadata_schemas.py | 2 +- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 79187941e1f..73cfabd6573 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -27,7 +27,11 @@ @overload -def log_metadata(metadata: Dict[str, MetadataType]) -> None: ... +def log_metadata( + *, + metadata: Dict[str, MetadataType], + log_related_entities: Optional[bool] = True, +) -> None: ... @overload @@ -35,6 +39,7 @@ def log_metadata( *, metadata: Dict[str, MetadataType], artifact_version_id: UUID, + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -44,6 +49,7 @@ def log_metadata( metadata: Dict[str, MetadataType], artifact_name: str, artifact_version: Optional[str] = None, + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -52,6 +58,7 @@ def log_metadata( *, metadata: Dict[str, MetadataType], model_version_id: UUID, + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -61,6 +68,7 @@ def log_metadata( metadata: Dict[str, MetadataType], model_name: str, model_version: str, + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -69,6 +77,7 @@ def log_metadata( *, metadata: Dict[str, MetadataType], step_id: UUID, + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -77,6 +86,7 @@ def log_metadata( *, metadata: Dict[str, MetadataType], run_id_name_or_prefix: Union[UUID, str], + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -86,6 +96,7 @@ def log_metadata( metadata: Dict[str, MetadataType], step_name: str, run_id_name_or_prefix: Union[UUID, str], + log_related_entities: Optional[bool] = True, ) -> None: ... @@ -103,6 +114,8 @@ def log_metadata( model_version_id: Optional[UUID] = None, model_name: Optional[str] = None, model_version: Optional[str] = None, + # Parameter to adjust whether we log to all related entities + log_related_entities: Optional[bool] = True, ) -> None: """Logs metadata for various resource types in a generalized way. @@ -117,6 +130,8 @@ def log_metadata( model_version_id: The ID of the model version. model_name: The name of the model. model_version: The version of the model + log_related_entities: Flag to decide whether we should log the same + metadata for related entities. Raises: ValueError: If no identifiers are provided and the function is not @@ -130,29 +145,37 @@ def log_metadata( run = client.get_pipeline_run(run_id_name_or_prefix) step = run.steps[step_name] - resources = [ - (run.id, MetadataResourceTypes.PIPELINE_RUN), - (step.id, MetadataResourceTypes.STEP_RUN), - ] - if step.model_version: - resources.append( - (step.model_version.id, MetadataResourceTypes.MODEL_VERSION) - ) + resources = [(step.id, MetadataResourceTypes.STEP_RUN)] + + if log_related_entities: + resources.append((run.id, MetadataResourceTypes.PIPELINE_RUN)) + if step.model_version: + resources.append( + ( + step.model_version.id, + MetadataResourceTypes.MODEL_VERSION, + ) + ) client.create_run_metadata(metadata=metadata, resources=resources) # If a step is identified by id, fetch it directly through the client, # follow a similar procedure and log metadata for its pipeline and model # as well. elif step_id is not None: - step = client.get_run_step(step_id) + resources = [(step_id, MetadataResourceTypes.STEP_RUN)] - resources = [ - (step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN), - (step.id, MetadataResourceTypes.STEP_RUN), - ] - if step.model_version: + if log_related_entities: + step = client.get_run_step(step_id) resources.append( - (step.model_version.id, MetadataResourceTypes.MODEL_VERSION) + (step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN) ) + + if step.model_version: + resources.append( + ( + step.model_version.id, + MetadataResourceTypes.MODEL_VERSION, + ) + ) client.create_run_metadata(metadata=metadata, resources=resources) # If a pipeline run id is identified, we need to log metadata to it and its @@ -162,7 +185,7 @@ def log_metadata( resources = [(run.id, MetadataResourceTypes.PIPELINE_RUN)] - if run.model_version: + if log_related_entities and run.model_version is not None: resources.append( (run.model_version.id, MetadataResourceTypes.MODEL_VERSION) ) diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index b3c3d713996..c927fd281f2 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -106,7 +106,7 @@ class RunMetadataResourceSchema(SQLModel, table=True): sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", overlaps="run_metadata,step_run,artifact_version,model_version", - ) + ), ) step_run: List["StepRunSchema"] = Relationship( back_populates="run_metadata", From 3b1ee3acd43454dfcf1f3479639af5be474ec0b2 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 20:01:58 +0100 Subject: [PATCH 059/124] fixing the toc --- docs/book/toc.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/toc.md b/docs/book/toc.md index e8c61c33932..f950f5cda1e 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -150,7 +150,7 @@ * [Load artifacts from Model](how-to/model-management-metrics/model-control-plane/load-artifacts-from-model.md) * [Track metrics and metadata](how-to/model-management-metrics/track-metrics-metadata/README.md) * [Attach metadata to a step](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md) - * [Attach metadata to a run](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) + * [Attach metadata to a run](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md) * [Attach metadata to an artifact](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md) * [Attach metadata to a model](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) * [Group metadata](how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md) From 6f2e2245a39704e9a4ed24540b920dc97d00f0f2 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 20:12:36 +0100 Subject: [PATCH 060/124] fixed the description --- .../track-metrics-metadata/attach-metadata-to-a-run.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md index 525b379e5b0..1e20f383fca 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -1,5 +1,5 @@ --- -description: Attaching metadata to a run. +description: Learn how to attach metadata to a run. --- # Attach Metadata to a Run From bb21a073db1e6643d33fd62405aa52b89e7ec2c4 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 20:57:28 +0100 Subject: [PATCH 061/124] docstring --- src/zenml/client.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 684e3f33ac4..a7a8e502a32 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4446,9 +4446,6 @@ def create_run_metadata( metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. - - Returns: - None """ from zenml.metadata.metadata_types import get_metadata_type From 791ddc037f16659f9742b8f01dde30eafb8d4cc5 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 20:58:05 +0100 Subject: [PATCH 062/124] spellcheck --- .../track-metrics-metadata/attach-metadata-to-a-run.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md index 1e20f383fca..5d1495b79c6 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -52,7 +52,7 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ {% hint style="warning" %} In order to log metadata to a pipeline run during the step execution without specifying any additional identifiers, `log_related_entities` should be -`True` (default behaviour). +`True` (default behavior). {% endhint %} ## Logging Metadata Outside a Run From bf771f70f783a5f38651059e83ec9441ba16864a Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 22:39:00 +0100 Subject: [PATCH 063/124] metadata creation during artifact version creation --- src/zenml/models/v2/core/run_metadata.py | 3 +- src/zenml/zen_stores/sql_zen_store.py | 46 +++++++++++++++++++----- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index da395ab0e6c..c11597699a3 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -34,7 +34,8 @@ class RunMetadataRequest(WorkspaceScopedRequest): title="The list of resources that this metadata belongs to." ) stack_component_id: Optional[UUID] = Field( - title="The ID of the stack component that this metadata belongs to." + title="The ID of the stack component that this metadata belongs to.", + default=None, ) values: Dict[str, "MetadataType"] = Field( title="The metadata to be created.", diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1d7ad2603cc..c2f743d40b8 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -26,6 +26,7 @@ from functools import lru_cache from pathlib import Path from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -355,6 +356,9 @@ SqlSecretsStoreConfiguration, ) +if TYPE_CHECKING: + from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum + AnyNamedSchema = TypeVar("AnyNamedSchema", bound=NamedSchema) AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -2916,17 +2920,41 @@ def create_artifact_version( # Save metadata of the artifact if artifact_version.metadata: + values: Dict[str, "MetadataType"] = {} + types: Dict[str, "MetadataTypeEnum"] = {} for key, value in artifact_version.metadata.items(): - run_metadata_schema = RunMetadataSchema( - workspace_id=artifact_version.workspace, - user_id=artifact_version.user, - resource_id=artifact_version_id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, - key=key, - value=json.dumps(value), - type=get_metadata_type(value), + # Skip metadata that is too large to be stored in the DB. + if len(json.dumps(value)) > TEXT_FIELD_MAX_LENGTH: + logger.warning( + f"Metadata value for key '{key}' is too large to be " + "stored in the database. Skipping." + ) + continue + # Skip metadata that is not of a supported type. + try: + metadata_type = get_metadata_type(value) + except ValueError as e: + logger.warning( + f"Metadata value for key '{key}' is not of a " + f"supported type. Skipping. Full error: {e}" + ) + continue + values[key] = value + types[key] = metadata_type + self.create_run_metadata( + RunMetadataRequest( + workspace=artifact_version.workspace, + user=artifact_version.user, + resources=[ + ( + artifact_version_id, + MetadataResourceTypes.ARTIFACT_VERSION, + ) + ], + values=values, + types=types, ) - session.add(run_metadata_schema) + ) session.commit() artifact_version_schema = session.exec( From 9ceeddecf6898506283d5c3aaddf36b4794c2c95 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 22:47:14 +0100 Subject: [PATCH 064/124] allowing artifact metadata with name for external artifact --- src/zenml/utils/metadata_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 73cfabd6573..213be8d81ec 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -238,7 +238,7 @@ def log_metadata( with contextlib.suppress(RuntimeError): step_context = get_step_context() - if step_context: + if step_context and artifact_name in step_context._outputs: step_context.add_output_metadata( metadata=metadata, output_name=artifact_name ) From 52fdba45867859410dfea79dff5de6bb4495e488 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 13 Nov 2024 23:01:08 +0100 Subject: [PATCH 065/124] update the template versions --- .github/workflows/update-templates-to-examples.yml | 8 ++++---- src/zenml/cli/base.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index 153f1a1aca2..f58b2a9424f 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -46,7 +46,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.30 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -118,7 +118,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.30 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -189,7 +189,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.30 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -261,7 +261,7 @@ jobs: with: python-version: ${{ inputs.python-version }} ref-zenml: ${{ github.ref }} - ref-template: 2024.11.08 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 7d429c8e701..1d5adecae0b 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -79,19 +79,19 @@ def copier_github_url(self) -> str: ZENML_PROJECT_TEMPLATES = dict( e2e_batch=ZenMLProjectTemplateLocation( github_url="zenml-io/template-e2e-batch", - github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), starter=ZenMLProjectTemplateLocation( github_url="zenml-io/template-starter", - github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), nlp=ZenMLProjectTemplateLocation( github_url="zenml-io/template-nlp", - github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), llm_finetuning=ZenMLProjectTemplateLocation( github_url="zenml-io/template-llm-finetuning", - github_tag="2024.11.08", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), ) From 1855edf99b24a5fb64dfcdb83419ac05e51acde4 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 13 Nov 2024 22:22:02 +0000 Subject: [PATCH 066/124] Auto-update of LLM Finetuning template --- examples/llm_finetuning/.copier-answers.yml | 2 +- examples/llm_finetuning/steps/log_metadata.py | 9 +++++++-- .../llm_finetuning/steps/prepare_datasets.py | 18 +++++++++++------- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 2c547f98d61..250f3b832e8 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08 +_commit: 2024.11.08-1-gd399790 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 diff --git a/examples/llm_finetuning/steps/log_metadata.py b/examples/llm_finetuning/steps/log_metadata.py index 645f98cc8ea..90109fdf3c4 100644 --- a/examples/llm_finetuning/steps/log_metadata.py +++ b/examples/llm_finetuning/steps/log_metadata.py @@ -17,7 +17,7 @@ from typing import Any, Dict -from zenml import get_step_context, log_model_metadata, step +from zenml import get_step_context, log_metadata, step @step(enable_cache=False) @@ -39,4 +39,9 @@ def log_metadata_from_step_artifact( metadata = {artifact_name: metadata_dict} - log_model_metadata(metadata) + if context.model: + log_metadata( + metadata=metadata, + model_name=context.model.name, + model_version=context.model.version, + ) diff --git a/examples/llm_finetuning/steps/prepare_datasets.py b/examples/llm_finetuning/steps/prepare_datasets.py index fe98126369d..00086bcdaf8 100644 --- a/examples/llm_finetuning/steps/prepare_datasets.py +++ b/examples/llm_finetuning/steps/prepare_datasets.py @@ -22,7 +22,7 @@ from typing_extensions import Annotated from utils.tokenizer import generate_and_tokenize_prompt, load_tokenizer -from zenml import log_model_metadata, step +from zenml import get_step_context, log_metadata, step from zenml.materializers import BuiltInMaterializer from zenml.utils.cuda_utils import cleanup_gpu_memory @@ -49,12 +49,16 @@ def prepare_data( cleanup_gpu_memory(force=True) - log_model_metadata( - { - "system_prompt": system_prompt, - "base_model_id": base_model_id, - } - ) + context = get_step_context() + if context.model: + log_metadata( + metadata={ + "system_prompt": system_prompt, + "base_model_id": base_model_id, + }, + model_name=context.model.name, + model_version=context.model.version, + ) tokenizer = load_tokenizer(base_model_id, False, use_fast) gen_and_tokenize = partial( From a2462b5e4e735e0c66ea74bc2ec618e595aba313 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 13 Nov 2024 22:22:31 +0000 Subject: [PATCH 067/124] Auto-update of Starter template --- examples/mlops_starter/.copier-answers.yml | 2 +- examples/mlops_starter/steps/data_preprocessor.py | 6 +++--- examples/mlops_starter/steps/model_evaluator.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index fd6b937c7c9..1c65d17e37c 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-3-g52bf387 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/examples/mlops_starter/steps/data_preprocessor.py b/examples/mlops_starter/steps/data_preprocessor.py index 0cf9d3ab521..f20cd93aa13 100644 --- a/examples/mlops_starter/steps/data_preprocessor.py +++ b/examples/mlops_starter/steps/data_preprocessor.py @@ -23,7 +23,7 @@ from typing_extensions import Annotated from utils.preprocess import ColumnsDropper, DataFrameCaster, NADropper -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step @step @@ -87,8 +87,8 @@ def data_preprocessor( dataset_tst = preprocess_pipeline.transform(dataset_tst) # Log metadata so we can load it in the inference pipeline - log_artifact_metadata( - artifact_name="preprocess_pipeline", + log_metadata( metadata={"random_state": random_state, "target": target}, + artifact_name="preprocess_pipeline", ) return dataset_trn, dataset_tst, preprocess_pipeline diff --git a/examples/mlops_starter/steps/model_evaluator.py b/examples/mlops_starter/steps/model_evaluator.py index 2a9b6ee9e75..a771d2fdd76 100644 --- a/examples/mlops_starter/steps/model_evaluator.py +++ b/examples/mlops_starter/steps/model_evaluator.py @@ -20,7 +20,7 @@ import pandas as pd from sklearn.base import ClassifierMixin -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -95,7 +95,7 @@ def model_evaluator( for message in messages: logger.warning(message) - log_artifact_metadata( + log_metadata( metadata={ "train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc), From 4283966e4372d85cc34da2c5254abd93d146d0e2 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 13 Nov 2024 22:29:01 +0000 Subject: [PATCH 068/124] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py | 2 +- examples/e2e/steps/hp_tuning/hp_tuning_single_search.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index cd687be59df..c5e21b74fcb 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-3-gd8d1576 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py index 65e524ecd98..1fc9a7cdc79 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py @@ -47,7 +47,7 @@ def hp_tuning_select_best_model( best_metric = -1 # consume artifacts attached to current model version in Model Control Plane for step_name in step_names: - hp_output = model.get_data_artifact("hp_result") + hp_output = model.get_artifact("hp_result") model_: ClassifierMixin = hp_output.load() # fetch metadata we attached earlier metric = float(hp_output.run_metadata["metric"]) diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py index f2f39969a6f..7948a011e7b 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py @@ -25,7 +25,7 @@ from typing_extensions import Annotated from utils import get_model_from_config -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -95,7 +95,7 @@ def hp_tuning_single_search( y_pred = cv.predict(X_tst) score = accuracy_score(y_tst, y_pred) # log score along with output artifact as metadata - log_artifact_metadata( + log_metadata( metadata={"metric": float(score)}, artifact_name="hp_result", ) From a915e62ad5550e1cc54b103c490cff13849a5ae6 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 13 Nov 2024 22:31:51 +0000 Subject: [PATCH 069/124] Auto-update of NLP template --- examples/e2e_nlp/.copier-answers.yml | 2 +- examples/e2e_nlp/steps/training/model_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index e13858e7da1..33820b0a2d2 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-1-g8d87577 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/e2e_nlp/steps/training/model_trainer.py b/examples/e2e_nlp/steps/training/model_trainer.py index edb9ab23ba5..812fe712ee4 100644 --- a/examples/e2e_nlp/steps/training/model_trainer.py +++ b/examples/e2e_nlp/steps/training/model_trainer.py @@ -30,7 +30,7 @@ from typing_extensions import Annotated from utils.misc import compute_metrics -from zenml import ArtifactConfig, log_artifact_metadata, step +from zenml import ArtifactConfig, log_metadata, step from zenml.client import Client from zenml.integrations.mlflow.experiment_trackers import ( MLFlowExperimentTracker, @@ -157,7 +157,7 @@ def model_trainer( eval_results = trainer.evaluate(metric_key_prefix="") # Log the evaluation results in model control plane - log_artifact_metadata( + log_metadata( metadata={"metrics": eval_results}, artifact_name="model", ) From f679fea64e5ed1399a546d390d683dd0c0ed64f0 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 14 Nov 2024 00:54:11 +0100 Subject: [PATCH 070/124] fixing the migration script --- .../cc269488e5a9_separate_run_metadata.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 6af578a7fef..6de54c6b823 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,5 +1,4 @@ """Separate run metadata into resource link table with new UUIDs. - Revision ID: cc269488e5a9 Revises: 0.70.0 Create Date: 2024-11-12 09:46:46.587478 @@ -33,7 +32,7 @@ def upgrade() -> None: sa.Column("resource_type", sa.String(length=255), nullable=False), sa.Column( "run_metadata_id", - sa.Integer, + sqlmodel.sql.sqltypes.GUID(), sa.ForeignKey("run_metadata.id", ondelete="CASCADE"), nullable=False, ), @@ -62,18 +61,25 @@ def upgrade() -> None: ] # Perform bulk insert into `run_metadata_resource` - op.bulk_insert( - sa.table( - "run_metadata_resource", - sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), - sa.Column( - "resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + if resource_data: # Only perform insert if there's data to migrate + op.bulk_insert( + sa.table( + "run_metadata_resource", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column( + "resource_type", sa.String(length=255), nullable=False + ), + sa.Column( + "run_metadata_id", + sqlmodel.sql.sqltypes.GUID(), + nullable=False, + ), # Changed to BIGINT ), - sa.Column("resource_type", sa.String(length=255), nullable=False), - sa.Column("run_metadata_id", sa.Integer, nullable=False), - ), - resource_data, - ) + resource_data, + ) # Drop the old `resource_id` and `resource_type` columns from `run_metadata` op.drop_column("run_metadata", "resource_id") From 64b0a0bf3f621e182b69c9d3d57ffd544d607ac6 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 14 Nov 2024 00:57:47 +0100 Subject: [PATCH 071/124] formatting --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 6de54c6b823..c345b3ffe4d 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,4 +1,5 @@ """Separate run metadata into resource link table with new UUIDs. + Revision ID: cc269488e5a9 Revises: 0.70.0 Create Date: 2024-11-12 09:46:46.587478 From df0fbda327d0577c170d721f4fb6b51e9d8424d0 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 18 Nov 2024 09:31:26 +0100 Subject: [PATCH 072/124] redirects --- .gitbook.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitbook.yaml b/.gitbook.yaml index 62e711dd848..4bba8d2a4f0 100644 --- a/.gitbook.yaml +++ b/.gitbook.yaml @@ -18,6 +18,7 @@ redirects: how-to/setting-up-a-project-repository/best-practices: how-to/project-setup-and-management/setting-up-a-project-repository/set-up-repository.md getting-started/zenml-pro/system-architectures: getting-started/system-architectures.md how-to/build-pipelines/name-your-pipeline-and-runs: how-to/pipeline-development/build-pipelines/name-your-pipeline-runs.md + how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps: how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md # Project Setup redirects how-to/setting-up-a-project-repository/: how-to/project-setup-and-management/setting-up-a-project-repository/README.md From 324e67f3b91ec13828294cb1e605b49f51bbe255 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 18 Nov 2024 11:12:03 +0100 Subject: [PATCH 073/124] minor fixes --- .../attach-metadata-to-a-model.md | 35 ++++++++++++------- src/zenml/client.py | 2 +- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index d6a3717daea..a8f85299978 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -26,7 +26,7 @@ import pandas as pd from sklearn.base import ClassifierMixin from sklearn.ensemble import RandomForestClassifier -from zenml import step, log_metadata, ArtifactConfig +from zenml import step, log_metadata, ArtifactConfig, get_step_context @step @@ -38,19 +38,23 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ """Train a model and log model metadata.""" classifier = RandomForestClassifier().fit(dataset) accuracy, precision, recall = ... + + step_context = get_step_context() + + if step_context.model: + # Log metadata for the model + log_metadata( + metadata={ + "evaluation_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + }, + model_name=step_context.model.name, + model_version=step_context.model.version, + ) - # Log metadata for the model - log_metadata( - metadata={ - "evaluation_metrics": { - "accuracy": accuracy, - "precision": precision, - "recall": recall - } - }, - model_name="zenml_model_name", - model_version="zenml_model_version" - ) return classifier ``` @@ -59,6 +63,11 @@ specific classifier artifact. This is particularly useful when the metadata reflects an aggregation or summary of various steps and artifacts in the pipeline. +{% hint style="info" %} +You can use the `get_step_context()` function to get fetch the model and model +version that the step is using. +{% endhint %} + ### Selecting Models with `log_metadata` When using `log_metadata` with a model, ZenML provides flexible options to diff --git a/src/zenml/client.py b/src/zenml/client.py index a7a8e502a32..3f7b67b0a5f 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4442,7 +4442,7 @@ def create_run_metadata( Args: metadata: The metadata to create as a dictionary of key-value pairs. - resources: The ID and type of the resources for that the + resources: The list of IDs and types of the resources for that the metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. From f966b0f192051563f22673b066a6ddea178734d9 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 20 Nov 2024 18:02:22 +0100 Subject: [PATCH 074/124] working pipelines again --- src/zenml/artifacts/utils.py | 7 +- src/zenml/client.py | 10 +- src/zenml/model/model.py | 7 +- src/zenml/models/__init__.py | 6 + src/zenml/models/v2/core/run_metadata.py | 9 +- src/zenml/models/v2/misc/run_metadata.py | 36 +++++ src/zenml/orchestrators/publish_utils.py | 13 +- src/zenml/steps/utils.py | 7 +- src/zenml/utils/metadata_utils.py | 150 ++++++++---------- .../routers/workspaces_endpoints.py | 18 +-- .../zen_stores/schemas/artifact_schemas.py | 4 +- src/zenml/zen_stores/schemas/model_schemas.py | 6 +- .../schemas/pipeline_run_schemas.py | 51 +++++- .../schemas/run_metadata_schemas.py | 28 ++-- .../zen_stores/schemas/step_run_schemas.py | 60 ++++++- src/zenml/zen_stores/sql_zen_store.py | 6 +- tests/integration/functional/test_client.py | 18 ++- .../functional/zen_stores/test_zen_store.py | 3 +- 18 files changed, 300 insertions(+), 139 deletions(-) create mode 100644 src/zenml/models/v2/misc/run_metadata.py diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 76fa22eb0e1..c23e6604acb 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -59,6 +59,7 @@ ArtifactVisualizationRequest, LoadedVisualization, PipelineRunResponse, + RunMetadataResource, StepRunResponse, StepRunUpdate, ) @@ -440,7 +441,11 @@ def log_artifact_metadata( response = client.get_artifact_version(artifact_name, artifact_version) client.create_run_metadata( metadata=metadata, - resources=[(response.id, MetadataResourceTypes.ARTIFACT_VERSION)], + resources=[ + RunMetadataResource( + id=response.id, type=MetadataResourceTypes.ARTIFACT_VERSION + ) + ], ) else: diff --git a/src/zenml/client.py b/src/zenml/client.py index 3f7b67b0a5f..f2e74f6cf69 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -60,7 +60,6 @@ from zenml.enums import ( ArtifactType, LogicalOperators, - MetadataResourceTypes, ModelStages, OAuthDeviceStatus, PluginSubType, @@ -137,6 +136,7 @@ PipelineRunFilter, PipelineRunResponse, RunMetadataRequest, + RunMetadataResource, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -4435,8 +4435,9 @@ def _delete_artifact_from_artifact_store( def create_run_metadata( self, metadata: Dict[str, "MetadataType"], - resources: List[Tuple[UUID, MetadataResourceTypes]], + resources: List[RunMetadataResource], stack_component_id: Optional[UUID] = None, + publisher_step_id: Optional[UUID] = None, ) -> None: """Create run metadata. @@ -4446,6 +4447,10 @@ def create_run_metadata( metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. + publisher_step_id: The ID of the step which published this metadata. + + Returns: + None """ from zenml.metadata.metadata_types import get_metadata_type @@ -4476,6 +4481,7 @@ def create_run_metadata( user=self.active_user.id, resources=resources, stack_component_id=stack_component_id, + publisher_step_id=publisher_step_id, values=values, types=types, ) diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 39db25eb657..0f4330a9a9f 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -337,11 +337,16 @@ def log_metadata( metadata: The metadata to log. """ from zenml.client import Client + from zenml.models import RunMetadataResource response = self._get_or_create_model_version() Client().create_run_metadata( metadata=metadata, - resources=[(response.id, MetadataResourceTypes.MODEL_VERSION)], + resources=[ + RunMetadataResource( + id=response.id, type=MetadataResourceTypes.MODEL_VERSION + ) + ], ) @property diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 5d6cc0c9125..186880c7c9e 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -368,6 +368,10 @@ OAuthRedirectResponse, OAuthTokenResponse, ) +from zenml.models.v2.misc.run_metadata import ( + RunMetadataEntry, + RunMetadataResource, +) from zenml.models.v2.misc.server_models import ( ServerModel, ServerDatabaseType, @@ -747,4 +751,6 @@ "ServiceConnectorInfo", "ServiceConnectorResourcesInfo", "ResourcesInfo", + "RunMetadataEntry", + "RunMetadataResource", ] diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index c11597699a3..97f966c4cf7 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,16 +13,16 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from uuid import UUID from pydantic import Field -from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, ) +from zenml.models.v2.misc.run_metadata import RunMetadataResource # ------------------ Request Model ------------------ @@ -30,7 +30,7 @@ class RunMetadataRequest(WorkspaceScopedRequest): """Request model for run metadata.""" - resources: List[Tuple[UUID, MetadataResourceTypes]] = Field( + resources: List[RunMetadataResource] = Field( title="The list of resources that this metadata belongs to." ) stack_component_id: Optional[UUID] = Field( @@ -43,3 +43,6 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) + publisher_step_id: Optional[UUID] = Field( + title="The ID of the step who published this metadata." + ) diff --git a/src/zenml/models/v2/misc/run_metadata.py b/src/zenml/models/v2/misc/run_metadata.py new file mode 100644 index 00000000000..5ace56a512d --- /dev/null +++ b/src/zenml/models/v2/misc/run_metadata.py @@ -0,0 +1,36 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.enums import MetadataResourceTypes +from zenml.metadata.metadata_types import MetadataType + + +class RunMetadataResource(BaseModel): + """Utility class to help identify resources to tag metadata to.""" + + id: UUID = Field(title="The ID of the resource.") + type: MetadataResourceTypes = Field(title="The type of the resource.") + + +class RunMetadataEntry(BaseModel): + """Utility class to sort/list run metadata entries.""" + + value: MetadataType = Field(title="The value for the run metadata entry") + created: datetime = Field( + title="The timestamp when this resource was created." + ) diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index a6d864aae32..ed8a01cfae6 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -21,6 +21,7 @@ from zenml.models import ( PipelineRunResponse, PipelineRunUpdate, + RunMetadataResource, StepRunResponse, StepRunUpdate, ) @@ -129,7 +130,11 @@ def publish_pipeline_run_metadata( for stack_component_id, metadata in pipeline_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resources=[(pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN)], + resources=[ + RunMetadataResource( + id=pipeline_run_id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ], stack_component_id=stack_component_id, ) @@ -149,6 +154,10 @@ def publish_step_run_metadata( for stack_component_id, metadata in step_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resources=[(step_run_id, MetadataResourceTypes.STEP_RUN)], + resources=[ + RunMetadataResource( + id=step_run_id, type=MetadataResourceTypes.STEP_RUN + ) + ], stack_component_id=stack_component_id, ) diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index 4cdce1d85e3..c52737f8a77 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -34,6 +34,7 @@ from zenml.exceptions import StepInterfaceError from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType +from zenml.models import RunMetadataResource from zenml.steps.step_context import get_step_context from zenml.utils import settings_utils, source_code_utils, typing_utils @@ -477,7 +478,11 @@ def log_step_metadata( step_run_id = pipeline_run.steps[step_name].id client.create_run_metadata( metadata=metadata, - resources=[(step_run_id, MetadataResourceTypes.STEP_RUN)], + resources=[ + RunMetadataResource( + id=step_run_id, type=MetadataResourceTypes.STEP_RUN + ) + ], ) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 213be8d81ec..c37b6739584 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -21,6 +21,7 @@ from zenml.enums import MetadataResourceTypes from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType +from zenml.models import RunMetadataResource from zenml.steps.step_context import get_step_context logger = get_logger(__name__) @@ -30,7 +31,6 @@ def log_metadata( *, metadata: Dict[str, MetadataType], - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -39,7 +39,6 @@ def log_metadata( *, metadata: Dict[str, MetadataType], artifact_version_id: UUID, - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -49,7 +48,6 @@ def log_metadata( metadata: Dict[str, MetadataType], artifact_name: str, artifact_version: Optional[str] = None, - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -58,7 +56,6 @@ def log_metadata( *, metadata: Dict[str, MetadataType], model_version_id: UUID, - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -68,7 +65,6 @@ def log_metadata( metadata: Dict[str, MetadataType], model_name: str, model_version: str, - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -77,7 +73,6 @@ def log_metadata( *, metadata: Dict[str, MetadataType], step_id: UUID, - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -86,7 +81,6 @@ def log_metadata( *, metadata: Dict[str, MetadataType], run_id_name_or_prefix: Union[UUID, str], - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -96,7 +90,6 @@ def log_metadata( metadata: Dict[str, MetadataType], step_name: str, run_id_name_or_prefix: Union[UUID, str], - log_related_entities: Optional[bool] = True, ) -> None: ... @@ -139,60 +132,46 @@ def log_metadata( """ client = Client() - # If a step name is provided, we need a run_id_name_or_prefix and will log - # metadata for the steps pipeline and model accordingly. + # Log metadata to a step by name and run ID if step_name is not None and run_id_name_or_prefix is not None: - run = client.get_pipeline_run(run_id_name_or_prefix) - step = run.steps[step_name] - - resources = [(step.id, MetadataResourceTypes.STEP_RUN)] - - if log_related_entities: - resources.append((run.id, MetadataResourceTypes.PIPELINE_RUN)) - if step.model_version: - resources.append( - ( - step.model_version.id, - MetadataResourceTypes.MODEL_VERSION, - ) + step_model_id = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + )[step_name].id + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=step_model_id, type=MetadataResourceTypes.STEP_RUN ) - client.create_run_metadata(metadata=metadata, resources=resources) - # If a step is identified by id, fetch it directly through the client, - # follow a similar procedure and log metadata for its pipeline and model - # as well. - elif step_id is not None: - resources = [(step_id, MetadataResourceTypes.STEP_RUN)] - - if log_related_entities: - step = client.get_run_step(step_id) - resources.append( - (step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN) - ) + ], + ) - if step.model_version: - resources.append( - ( - step.model_version.id, - MetadataResourceTypes.MODEL_VERSION, - ) + # Log metadata to a step by ID + elif step_id is not None: + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=step_id, type=MetadataResourceTypes.STEP_RUN ) - client.create_run_metadata(metadata=metadata, resources=resources) + ], + ) - # If a pipeline run id is identified, we need to log metadata to it and its - # model as well. + # Log metadata to a run by ID elif run_id_name_or_prefix is not None: - run = client.get_pipeline_run(run_id_name_or_prefix) - - resources = [(run.id, MetadataResourceTypes.PIPELINE_RUN)] - - if log_related_entities and run.model_version is not None: - resources.append( - (run.model_version.id, MetadataResourceTypes.MODEL_VERSION) - ) - client.create_run_metadata(metadata=metadata, resources=resources) + run_model = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + ) + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=run_model.id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ], + ) - # If the user provides a model name and version, we use to model abstraction - # to fetch the model version and attach the corresponding metadata to it. + # Log metadata to a model version by name and version elif model_name is not None and model_version is not None: from zenml import Model @@ -200,16 +179,22 @@ def log_metadata( client.create_run_metadata( metadata=metadata, - resources=[(mv.id, MetadataResourceTypes.MODEL_VERSION)], + resources=[ + RunMetadataResource( + id=mv.id, type=MetadataResourceTypes.MODEL_VERSION + ) + ], ) - # If the user provides a model version id, we use the client to fetch it and - # attach the metadata to it. + # Log metadata to a model version by id elif model_version_id is not None: client.create_run_metadata( metadata=metadata, resources=[ - (model_version_id, MetadataResourceTypes.MODEL_VERSION) + RunMetadataResource( + id=model_version_id, + type=MetadataResourceTypes.MODEL_VERSION, + ) ], ) @@ -227,9 +212,9 @@ def log_metadata( client.create_run_metadata( metadata=metadata, resources=[ - ( - artifact_version_model.id, - MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataResource( + id=artifact_version_model.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, ) ], ) @@ -249,9 +234,9 @@ def log_metadata( client.create_run_metadata( metadata=metadata, resources=[ - ( - artifact_version_model.id, - MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataResource( + id=artifact_version_model.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, ) ], ) @@ -262,7 +247,10 @@ def log_metadata( client.create_run_metadata( metadata=metadata, resources=[ - (artifact_version_id, MetadataResourceTypes.ARTIFACT_VERSION) + RunMetadataResource( + id=artifact_version_id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) ], ) @@ -293,20 +281,15 @@ def log_metadata( "of the step execution, please provide the required " "identifiers." ) - resources = [ - (step_context.step_run.id, MetadataResourceTypes.STEP_RUN), - (step_context.pipeline_run.id, MetadataResourceTypes.PIPELINE_RUN), - ] - if step_context.model_version: - resources.append( - ( - step_context.model_version.id, - MetadataResourceTypes.MODEL_VERSION, - ) - ) client.create_run_metadata( metadata=metadata, - resources=resources, + resources=[ + RunMetadataResource( + id=step_context.step_run.id, + type=MetadataResourceTypes.STEP_RUN, + ) + ], + publisher_step_id=step_context.step_run.id, ) else: @@ -316,23 +299,20 @@ def log_metadata( include: # Inside a step - # Logs the metadata to the step, its run and possibly its model log_metadata(metadata={}) - # Manually logging for a step - # Logs the metadata to the step, its run and possibly its model + # Manual logging for a step log_metadata(metadata={}, step_name=..., run_id_name_or_prefix=...) log_metadata(metadata={}, step_id=...) - # Manually logging for a run - # Logs the metadata to the run, possibly its model + # Manual logging for a run log_metadata(metadata={}, run_id_name_or_prefix=...) - # Manually logging for a model + # Manual logging for a model log_metadata(metadata={}, model_name=..., model_version=...) log_metadata(metadata={}, model_version_id=...) - # Manually logging for an artifact + # Manual logging for an artifact log_metadata(metadata={}, artifact_name=...) # inside a step log_metadata(metadata={}, artifact_name=..., artifact_version=...) log_metadata(metadata={}, artifact_version_id=...) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 297d5a37ca4..409199f0d24 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1015,20 +1015,20 @@ def create_run_metadata( ) for resource in run_metadata.resources: - if resource[1] == MetadataResourceTypes.PIPELINE_RUN: - run = zen_store().get_run(resource[0]) + if resource.type == MetadataResourceTypes.PIPELINE_RUN: + run = zen_store().get_run(resource.id) verify_permission_for_model(run, action=Action.UPDATE) - elif resource[1] == MetadataResourceTypes.STEP_RUN: - step = zen_store().get_run_step(resource[0]) + elif resource.type == MetadataResourceTypes.STEP_RUN: + step = zen_store().get_run_step(resource.id) verify_permission_for_model(step, action=Action.UPDATE) - elif resource[1] == MetadataResourceTypes.ARTIFACT_VERSION: - artifact_version = zen_store().get_artifact_version(resource[0]) + elif resource.type == MetadataResourceTypes.ARTIFACT_VERSION: + artifact_version = zen_store().get_artifact_version(resource.id) verify_permission_for_model(artifact_version, action=Action.UPDATE) - elif resource[1] == MetadataResourceTypes.MODEL_VERSION: - model_version = zen_store().get_model_version(resource[0]) + elif resource.type == MetadataResourceTypes.MODEL_VERSION: + model_version = zen_store().get_model_version(resource.id) verify_permission_for_model(model_version, action=Action.UPDATE) else: - raise RuntimeError(f"Unknown resource type: {resource[1]}") + raise RuntimeError(f"Unknown resource type: {resource.type}") verify_permission( resource_type=ResourceType.RUN_METADATA, action=Action.CREATE diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 7697a721614..32c5c1749f6 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -244,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, table=True): workspace: "WorkspaceSchema" = Relationship( back_populates="artifact_versions" ) - run_metadata: List["RunMetadataResourceSchema"] = Relationship( + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index ba22ab0aa57..5418c5a32b9 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -305,12 +305,12 @@ class ModelVersionSchema(NamedSchema, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata: List["RunMetadataResourceSchema"] = Relationship( + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) pipeline_runs: List["PipelineRunSchema"] = Relationship( @@ -407,7 +407,7 @@ def to_model( description=self.description, run_metadata={ m.run_metadata.key: json.loads(m.run_metadata.value) - for m in self.run_metadata + for m in self.run_metadata_resources }, ) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index c49565d6bda..77c56ed562d 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from uuid import UUID from pydantic import ConfigDict @@ -28,12 +28,14 @@ MetadataResourceTypes, TaggableResourceTypes, ) +from zenml.metadata.metadata_types import MetadataType from zenml.models import ( PipelineRunRequest, PipelineRunResponse, PipelineRunResponseBody, PipelineRunResponseMetadata, PipelineRunUpdate, + RunMetadataEntry, ) from zenml.models.v2.core.pipeline_run import PipelineRunResponseResources from zenml.zen_stores.schemas.base_schemas import NamedSchema @@ -138,12 +140,12 @@ class PipelineRunSchema(NamedSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata: List["RunMetadataResourceSchema"] = Relationship( + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="pipeline_run", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) logs: Optional["LogsSchema"] = Relationship( @@ -251,6 +253,44 @@ def from_request( model_version_id=request.model_version_id, ) + def fetch_metadata_collection( + self, latest_values_only: True + ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: + """Fetches the metadata of related to the pipeline run. + + Returns: + a dictionary, where the key is the name of the metadata and the + values represent the entries under this name. + """ + metadata_dict = {} + + # Fetch the metadata related to this run + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_dict: + metadata_dict[rm.run_metadata.key] = [] + metadata_dict[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + # Fetch the metadata related to the steps of this run + for s in self.step_runs: + step_metadata = s.fetch_metadata_collection( + latest_values_only=False + ) + for k, v in step_metadata.items(): + metadata_dict[f"{s.name}::{k}"] = v + + # If we get only the latest values, sort by created and get the first + if latest_values_only: + for k, v in metadata_dict.items(): + metadata_dict[k] = sorted( + v, key=lambda x: x.created, reverse=True + )[0].value + + return metadata_dict + def to_model( self, include_metadata: bool = False, @@ -277,10 +317,7 @@ def to_model( else {} ) - run_metadata = { - m.run_metadata.key: json.loads(m.run_metadata.value) - for m in self.run_metadata - } + run_metadata = self.fetch_metadata_collection(latest_values_only=True) if self.deployment is not None: deployment = self.deployment.to_model() diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index c927fd281f2..7b28b3bde98 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -23,6 +23,7 @@ from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -30,7 +31,6 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema - from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema class RunMetadataSchema(BaseSchema, table=True): @@ -75,6 +75,16 @@ class RunMetadataSchema(BaseSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="run_metadata") + publisher_step_id: UUID = build_foreign_key_field( + source=__tablename__, + target=StepRunSchema.__tablename__, + source_column="publisher_step_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + publisher_step: Optional["StepRunSchema"] = Relationship() + key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str @@ -102,30 +112,30 @@ class RunMetadataResourceSchema(SQLModel, table=True): # Relationship to link specific resource types pipeline_run: List["PipelineRunSchema"] = Relationship( - back_populates="run_metadata", + back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", - overlaps="run_metadata,step_run,artifact_version,model_version", + overlaps="run_metadata_resources,step_run,artifact_version,model_version", ), ) step_run: List["StepRunSchema"] = Relationship( - back_populates="run_metadata", + back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", - overlaps="run_metadata,pipeline_run,artifact_version,model_version", + overlaps="run_metadata_resources,pipeline_run,artifact_version,model_version", ), ) artifact_version: List["ArtifactVersionSchema"] = Relationship( - back_populates="run_metadata", + back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="run_metadata,pipeline_run,step_run,model_version", + overlaps="run_metadata_resources,pipeline_run,step_run,model_version", ), ) model_version: List["ModelVersionSchema"] = Relationship( - back_populates="run_metadata", + back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", - overlaps="run_metadata,pipeline_run,step_run,artifact_version", + overlaps="run_metadata_resources,pipeline_run,step_run,artifact_version", ), ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index f7ff4ce572d..5d9dd11cdc7 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from uuid import UUID from pydantic import ConfigDict @@ -30,7 +30,9 @@ MetadataResourceTypes, StepRunInputArtifactType, ) +from zenml.metadata.metadata_types import MetadataType from zenml.models import ( + RunMetadataEntry, StepRunRequest, StepRunResponse, StepRunResponseBody, @@ -141,12 +143,12 @@ class StepRunSchema(NamedSchema, table=True): deployment: Optional["PipelineDeploymentSchema"] = Relationship( back_populates="step_runs" ) - run_metadata: List["RunMetadataResourceSchema"] = Relationship( + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="step_run", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) input_artifacts: List["StepRunInputArtifactSchema"] = Relationship( @@ -168,6 +170,9 @@ class StepRunSchema(NamedSchema, table=True): model_version: "ModelVersionSchema" = Relationship( back_populates="step_runs", ) + original_step_run: Optional["StepRunSchema"] = Relationship( + sa_relationship_kwargs={"remote_side": "StepRunSchema.id"} + ) model_config = ConfigDict(protected_namespaces=()) # type: ignore[assignment] @@ -198,6 +203,50 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": model_version_id=request.model_version_id, ) + def fetch_metadata_collection( + self, latest_values_only: True + ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: + """Fetches the metadata of related to the pipeline run. + + Returns: + a dictionary, where the key is the name of the metadata and the + values represent the entries under this name. + """ + metadata_dict = {} + + # Fetch the metadata related to this step + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_dict: + metadata_dict[rm.run_metadata.key] = [] + metadata_dict[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + + # Fetch the metadata related to the original step of this cached step + if self.original_step_run: + for metadata in self.original_step_run.run_metadata_resources: + if metadata.publisher_step_id is not None: + if metadata.key not in metadata_dict: + metadata_dict[metadata.key] = [] + metadata_dict[metadata.key].append( + RunMetadataEntry( + value=json.loads(metadata.value), + created=metadata.created, + ) + ) + + # If we get only the latest values, sort by created and get the first + if latest_values_only: + for k, v in metadata_dict.items(): + metadata_dict[k] = sorted( + v, key=lambda x: x.created, reverse=True + )[0].value + + return metadata_dict + def to_model( self, include_metadata: bool = False, @@ -220,10 +269,7 @@ def to_model( RuntimeError: If the step run schema does not have a deployment_id or a step_configuration. """ - run_metadata = { - m.run_metadata.key: json.loads(m.run_metadata.value) - for m in self.run_metadata - } + run_metadata = self.fetch_metadata_collection(latest_values_only=True) input_artifacts = { artifact.name: StepRunInputResponse( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index c2f743d40b8..9633c010ab9 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5555,19 +5555,19 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: key=key, value=json.dumps(value), type=type_, + publisher_step_id=run_metadata.publisher_step_id, ) session.add(run_metadata_schema) session.commit() for resource in run_metadata.resources: rm_resource_link = RunMetadataResourceSchema( - resource_id=resource[0], - resource_type=resource[1].value, + resource_id=resource.id, + resource_type=resource.type.value, run_metadata_id=run_metadata_schema.id, ) session.add(rm_resource_link) session.commit() - return None # ----------------------------- Schedules ----------------------------- diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 6c916326342..c9f0cb3554a 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -63,6 +63,7 @@ PipelineBuildRequest, PipelineDeploymentRequest, PipelineRequest, + RunMetadataResource, StackResponse, ) from zenml.utils import io_utils @@ -484,7 +485,11 @@ def test_create_run_metadata_for_pipeline_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resources=[(pipeline_run.id, MetadataResourceTypes.PIPELINE_RUN)], + resources=[ + RunMetadataResource( + id=pipeline_run.id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ], ) rm = clean_client_with_run.get_pipeline_run(pipeline_run.id).run_metadata @@ -500,7 +505,11 @@ def test_create_run_metadata_for_step_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resources=[(step_run.id, MetadataResourceTypes.STEP_RUN)], + resources=[ + RunMetadataResource( + id=step_run.id, type=MetadataResourceTypes.STEP_RUN + ) + ], ) rm = clean_client_with_run.get_run_step(step_run.id).run_metadata @@ -517,7 +526,10 @@ def test_create_run_metadata_for_artifact(clean_client_with_run: Client): clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, resources=[ - (artifact_version.id, MetadataResourceTypes.ARTIFACT_VERSION) + RunMetadataResource( + id=artifact_version.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) ], ) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index c894d567c71..ec1e8584663 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -102,6 +102,7 @@ ModelVersionUpdate, PipelineRunFilter, PipelineRunResponse, + RunMetadataResource, ServiceAccountFilter, ServiceAccountRequest, ServiceAccountUpdate, @@ -5500,7 +5501,7 @@ def test_metadata_full_cycle_with_cascade_deletion( RunMetadataRequest( user=client.active_user.id, workspace=client.active_workspace.id, - resources=[(resource.id, type_)], + resources=[RunMetadataResource(id=resource.id, type=type_)], values={"foo": "bar"}, types={"foo": MetadataTypeEnum.STRING}, stack_component_id=sc.id From 18e5f1ebadf654445f6450d54a0f27eb380a3437 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 20 Nov 2024 18:02:34 +0100 Subject: [PATCH 075/124] small fix --- src/zenml/zen_stores/schemas/step_run_schemas.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 5d9dd11cdc7..d0122919ef5 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -228,13 +228,13 @@ def fetch_metadata_collection( # Fetch the metadata related to the original step of this cached step if self.original_step_run: for metadata in self.original_step_run.run_metadata_resources: - if metadata.publisher_step_id is not None: - if metadata.key not in metadata_dict: - metadata_dict[metadata.key] = [] - metadata_dict[metadata.key].append( + if metadata.run_metadata.publisher_step_id is not None: + if metadata.run_metadata.key not in metadata_dict: + metadata_dict[metadata.run_metadata.key] = [] + metadata_dict[metadata.run_metadata.key].append( RunMetadataEntry( - value=json.loads(metadata.value), - created=metadata.created, + value=json.loads(metadata.run_metadata.value), + created=metadata.run_metadata.created, ) ) From 05f8ab5397c08f7b3a8c0cc9dd17253b81dac548 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 01:07:31 +0100 Subject: [PATCH 076/124] working checkpoint --- src/zenml/client.py | 7 ++-- src/zenml/models/v2/core/run_metadata.py | 6 ++- src/zenml/models/v2/misc/run_metadata.py | 2 + src/zenml/utils/metadata_utils.py | 10 +++-- .../cc269488e5a9_separate_run_metadata.py | 15 ++++++- .../schemas/pipeline_run_schemas.py | 23 ++++++----- .../schemas/run_metadata_schemas.py | 11 +---- .../zen_stores/schemas/step_run_schemas.py | 15 +------ src/zenml/zen_stores/sql_zen_store.py | 40 ++++++++++++++++++- 9 files changed, 82 insertions(+), 47 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index ede2f090f0a..c4619b6b998 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4440,7 +4440,7 @@ def create_run_metadata( metadata: Dict[str, "MetadataType"], resources: List[RunMetadataResource], stack_component_id: Optional[UUID] = None, - publisher_step_id: Optional[UUID] = None, + cached: bool = False, ) -> None: """Create run metadata. @@ -4450,7 +4450,8 @@ def create_run_metadata( metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. - publisher_step_id: The ID of the step which published this metadata. + cached: A flag indicating if the run metadata can be cached during + a step execution. Returns: None @@ -4484,7 +4485,7 @@ def create_run_metadata( user=self.active_user.id, resources=resources, stack_component_id=stack_component_id, - publisher_step_id=publisher_step_id, + cached=cached, values=values, types=types, ) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 97f966c4cf7..a53f1cba70e 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -43,6 +43,8 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) - publisher_step_id: Optional[UUID] = Field( - title="The ID of the step who published this metadata." + cached: Optional[bool] = Field( + title="A flag indicating if the run metadata is cached through " + "a step execution.", + default=False, ) diff --git a/src/zenml/models/v2/misc/run_metadata.py b/src/zenml/models/v2/misc/run_metadata.py index 5ace56a512d..3d5bb75ec69 100644 --- a/src/zenml/models/v2/misc/run_metadata.py +++ b/src/zenml/models/v2/misc/run_metadata.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +"""Utility classes for modelling run metadata.""" + from datetime import datetime from uuid import UUID diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index c37b6739584..aa5d3bf4286 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -134,9 +134,11 @@ def log_metadata( # Log metadata to a step by name and run ID if step_name is not None and run_id_name_or_prefix is not None: - step_model_id = client.get_pipeline_run( - name_id_or_prefix=run_id_name_or_prefix - )[step_name].id + step_model_id = ( + client.get_pipeline_run(name_id_or_prefix=run_id_name_or_prefix) + .steps[step_name] + .id + ) client.create_run_metadata( metadata=metadata, resources=[ @@ -289,7 +291,7 @@ def log_metadata( type=MetadataResourceTypes.STEP_RUN, ) ], - publisher_step_id=step_context.step_run.id, + cached=True, ) else: diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index c345b3ffe4d..922c0f14b41 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,7 +1,7 @@ """Separate run metadata into resource link table with new UUIDs. Revision ID: cc269488e5a9 -Revises: 0.70.0 +Revises: ec6307720f92 Create Date: 2024-11-12 09:46:46.587478 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision = "cc269488e5a9" -down_revision = "0.70.0" +down_revision = "ec6307720f92" branch_labels = None depends_on = None @@ -86,6 +86,14 @@ def upgrade() -> None: op.drop_column("run_metadata", "resource_id") op.drop_column("run_metadata", "resource_type") + # Add the cached column to the database table + op.add_column( + "run_metadata", + sa.Column( + "cached", sa.Boolean(), nullable=True, server_default=sa.false() + ), + ) + def downgrade() -> None: """Reverts the 'run_metadata_resource' table and migrates data back.""" @@ -127,3 +135,6 @@ def downgrade() -> None: # Drop the `run_metadata_resource` table op.drop_table("run_metadata_resource") + + # Drop the cached column + op.drop_column("run_metadata", "cached") diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 77c56ed562d..48515a17c6f 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -254,7 +254,7 @@ def from_request( ) def fetch_metadata_collection( - self, latest_values_only: True + self, latest_values_only: bool = True ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: """Fetches the metadata of related to the pipeline run. @@ -262,34 +262,35 @@ def fetch_metadata_collection( a dictionary, where the key is the name of the metadata and the values represent the entries under this name. """ - metadata_dict = {} + metadata_collection: dict[str, List[RunMetadataEntry]] = {} # Fetch the metadata related to this run for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_dict: - metadata_dict[rm.run_metadata.key] = [] - metadata_dict[rm.run_metadata.key].append( + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( RunMetadataEntry( value=json.loads(rm.run_metadata.value), created=rm.run_metadata.created, ) ) + # Fetch the metadata related to the steps of this run for s in self.step_runs: step_metadata = s.fetch_metadata_collection( latest_values_only=False ) for k, v in step_metadata.items(): - metadata_dict[f"{s.name}::{k}"] = v + metadata_collection[f"{s.name}::{k}"] = v # If we get only the latest values, sort by created and get the first if latest_values_only: - for k, v in metadata_dict.items(): - metadata_dict[k] = sorted( - v, key=lambda x: x.created, reverse=True - )[0].value + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } - return metadata_dict + return metadata_collection def to_model( self, diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 7b28b3bde98..8c61ff5c98a 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -75,19 +75,10 @@ class RunMetadataSchema(BaseSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="run_metadata") - publisher_step_id: UUID = build_foreign_key_field( - source=__tablename__, - target=StepRunSchema.__tablename__, - source_column="publisher_step_id", - target_column="id", - ondelete="SET NULL", - nullable=True, - ) - publisher_step: Optional["StepRunSchema"] = Relationship() - key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str + cached: Optional[bool] = Field(default=False) class RunMetadataResourceSchema(SQLModel, table=True): diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index d0122919ef5..a0a568ceafc 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -204,7 +204,7 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": ) def fetch_metadata_collection( - self, latest_values_only: True + self, latest_values_only: bool = True ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: """Fetches the metadata of related to the pipeline run. @@ -225,19 +225,6 @@ def fetch_metadata_collection( ) ) - # Fetch the metadata related to the original step of this cached step - if self.original_step_run: - for metadata in self.original_step_run.run_metadata_resources: - if metadata.run_metadata.publisher_step_id is not None: - if metadata.run_metadata.key not in metadata_dict: - metadata_dict[metadata.run_metadata.key] = [] - metadata_dict[metadata.run_metadata.key].append( - RunMetadataEntry( - value=json.loads(metadata.run_metadata.value), - created=metadata.run_metadata.created, - ) - ) - # If we get only the latest values, sort by created and get the first if latest_values_only: for k, v in metadata_dict.items(): diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 9633c010ab9..dcdfc75ab0a 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5555,7 +5555,7 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: key=key, value=json.dumps(value), type=type_, - publisher_step_id=run_metadata.publisher_step_id, + cached=run_metadata.cached, ) session.add(run_metadata_schema) session.commit() @@ -8178,6 +8178,44 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: ) session.add(log_entry) + # If cached, attach metadata of the original step + if ( + step_run.status == ExecutionStatus.CACHED + and step_run.original_step_run_id is not None + ): + original_metadata_links = session.exec( + select(RunMetadataResourceSchema) + .join( + RunMetadataSchema, + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id, + ) + .where( + RunMetadataResourceSchema.resource_id + == step_run.original_step_run_id + ) + .where( + RunMetadataResourceSchema.resource_type + == MetadataResourceTypes.STEP_RUN + ) + .where(RunMetadataSchema.cached.is_(True)) + ).all() + + # Create new links in a batch + new_links = [ + RunMetadataResourceSchema( + resource_id=step_schema.id, + resource_type=link.resource_type, + run_metadata_id=link.run_metadata_id, + ) + for link in original_metadata_links + ] + # Add all new links in a single operation + session.add_all(new_links) + # Commit the changes + session.commit() + session.refresh(step_schema) + # Save parent step IDs into the database. for parent_step_id in step_run.parent_step_ids: self._set_run_step_parent_step( From 8a353b8703e69fdbac08f865c816b0a85d79baa8 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 01:56:27 +0100 Subject: [PATCH 077/124] fixes, linting, docstrings --- src/zenml/client.py | 3 -- .../zen_stores/schemas/artifact_schemas.py | 44 ++++++++++++++++--- src/zenml/zen_stores/schemas/model_schemas.py | 42 ++++++++++++++++-- .../schemas/pipeline_run_schemas.py | 40 +++++++++-------- .../zen_stores/schemas/step_run_schemas.py | 42 ++++++++++-------- src/zenml/zen_stores/schemas/utils.py | 3 +- src/zenml/zen_stores/sql_zen_store.py | 9 ++-- 7 files changed, 127 insertions(+), 56 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index c4619b6b998..cf11f646db5 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4452,9 +4452,6 @@ def create_run_metadata( the metadata. cached: A flag indicating if the run metadata can be cached during a step execution. - - Returns: - None """ from zenml.metadata.metadata_types import get_metadata_type diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 32c5c1749f6..58b5e358b78 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID from pydantic import ValidationError @@ -30,6 +30,7 @@ MetadataResourceTypes, TaggableResourceTypes, ) +from zenml.metadata.metadata_types import MetadataType from zenml.models import ( ArtifactResponse, ArtifactResponseBody, @@ -40,6 +41,7 @@ ArtifactVersionResponseBody, ArtifactVersionResponseMetadata, ArtifactVersionUpdate, + RunMetadataEntry, ) from zenml.models.v2.core.artifact import ArtifactRequest from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema @@ -308,6 +310,41 @@ def from_request( save_type=artifact_version_request.save_type.value, ) + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. + """ + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} + + # Fetch the metadata related to this step + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + + return metadata_collection + + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } + def to_model( self, include_metadata: bool = False, @@ -377,10 +414,7 @@ def to_model( workspace=self.workspace.to_model(), producer_step_run_id=producer_step_run_id, visualizations=[v.to_model() for v in self.visualizations], - run_metadata={ - m.run_metadata.key: json.loads(m.run_metadata.value) - for m in self.run_metadata - }, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index a8e511da3c4..ce082e6f76a 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -27,6 +27,7 @@ MetadataResourceTypes, TaggableResourceTypes, ) +from zenml.metadata.metadata_types import MetadataType from zenml.models import ( BaseResponseMetadata, ModelRequest, @@ -46,6 +47,7 @@ ModelVersionResponseMetadata, ModelVersionResponseResources, Page, + RunMetadataEntry, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema @@ -346,6 +348,41 @@ def from_request( stage=model_version_request.stage, ) + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the model version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. + """ + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} + + # Fetch the metadata related to this step + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + + return metadata_collection + + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the model version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } + def to_model( self, include_metadata: bool = False, @@ -404,10 +441,7 @@ def to_model( metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), description=self.description, - run_metadata={ - m.run_metadata.key: json.loads(m.run_metadata.value) - for m in self.run_metadata_resources - }, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 48515a17c6f..02cae95044c 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID from pydantic import ConfigDict @@ -253,16 +253,14 @@ def from_request( model_version_id=request.model_version_id, ) - def fetch_metadata_collection( - self, latest_values_only: bool = True - ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: - """Fetches the metadata of related to the pipeline run. + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the pipeline run. Returns: - a dictionary, where the key is the name of the metadata and the - values represent the entries under this name. + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. """ - metadata_collection: dict[str, List[RunMetadataEntry]] = {} + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} # Fetch the metadata related to this run for rm in self.run_metadata_resources: @@ -277,21 +275,25 @@ def fetch_metadata_collection( # Fetch the metadata related to the steps of this run for s in self.step_runs: - step_metadata = s.fetch_metadata_collection( - latest_values_only=False - ) + step_metadata = s.fetch_metadata_collection() for k, v in step_metadata.items(): metadata_collection[f"{s.name}::{k}"] = v - # If we get only the latest values, sort by created and get the first - if latest_values_only: - return { - k: sorted(v, key=lambda x: x.created, reverse=True)[0].value - for k, v in metadata_collection.items() - } - return metadata_collection + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the pipeline run. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } + def to_model( self, include_metadata: bool = False, @@ -318,7 +320,7 @@ def to_model( else {} ) - run_metadata = self.fetch_metadata_collection(latest_values_only=True) + run_metadata = self.fetch_metadata() if self.deployment is not None: deployment = self.deployment.to_model() diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index a0a568ceafc..a00e8cdfe94 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID from pydantic import ConfigDict @@ -203,36 +203,40 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": model_version_id=request.model_version_id, ) - def fetch_metadata_collection( - self, latest_values_only: bool = True - ) -> Dict[str, Union[MetadataType, List[RunMetadataEntry]]]: - """Fetches the metadata of related to the pipeline run. + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the step run. Returns: - a dictionary, where the key is the name of the metadata and the - values represent the entries under this name. + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. """ - metadata_dict = {} + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} # Fetch the metadata related to this step for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_dict: - metadata_dict[rm.run_metadata.key] = [] - metadata_dict[rm.run_metadata.key].append( + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( RunMetadataEntry( value=json.loads(rm.run_metadata.value), created=rm.run_metadata.created, ) ) - # If we get only the latest values, sort by created and get the first - if latest_values_only: - for k, v in metadata_dict.items(): - metadata_dict[k] = sorted( - v, key=lambda x: x.created, reverse=True - )[0].value + return metadata_collection - return metadata_dict + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the step run. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } def to_model( self, @@ -256,7 +260,7 @@ def to_model( RuntimeError: If the step run schema does not have a deployment_id or a step_configuration. """ - run_metadata = self.fetch_metadata_collection(latest_values_only=True) + run_metadata = self.fetch_metadata() input_artifacts = { artifact.name: StepRunInputResponse( diff --git a/src/zenml/zen_stores/schemas/utils.py b/src/zenml/zen_stores/schemas/utils.py index ad458a5423e..3c520d1eb86 100644 --- a/src/zenml/zen_stores/schemas/utils.py +++ b/src/zenml/zen_stores/schemas/utils.py @@ -16,8 +16,7 @@ import math from typing import List, Type, TypeVar -from zenml.models.v2.base.base import BaseResponse -from zenml.models.v2.base.page import Page +from zenml.models import BaseResponse, Page from zenml.zen_stores.schemas.base_schemas import BaseSchema S = TypeVar("S", bound=BaseSchema) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index dcdfc75ab0a..5aa56d9a32a 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -8185,10 +8185,9 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: ): original_metadata_links = session.exec( select(RunMetadataResourceSchema) - .join( - RunMetadataSchema, + .where( RunMetadataResourceSchema.run_metadata_id - == RunMetadataSchema.id, + == RunMetadataSchema.id ) .where( RunMetadataResourceSchema.resource_id @@ -8198,7 +8197,9 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: RunMetadataResourceSchema.resource_type == MetadataResourceTypes.STEP_RUN ) - .where(RunMetadataSchema.cached.is_(True)) + .where( + RunMetadataSchema.cached == True # noqa: E712 + ) ).all() # Create new links in a batch From af009b0b41fa365620ac4f8bdc08072f583a84f9 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:05:56 +0100 Subject: [PATCH 078/124] fixing unit tests --- src/zenml/zen_stores/sql_zen_store.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 5aa56d9a32a..f6802c292f0 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -220,6 +220,7 @@ PipelineRunUpdate, PipelineUpdate, RunMetadataRequest, + RunMetadataResource, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -2946,9 +2947,9 @@ def create_artifact_version( workspace=artifact_version.workspace, user=artifact_version.user, resources=[ - ( - artifact_version_id, - MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataResource( + id=artifact_version_id, + type=MetadataResourceTypes.ARTIFACT_VERSION, ) ], values=values, From fac74c9630502c8e7e894e2e972d392606a7f29b Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:17:54 +0100 Subject: [PATCH 079/124] docs updates 1 --- .../attach-metadata-to-a-model.md | 5 ++++ .../attach-metadata-to-a-run.md | 24 ++++++++++--------- .../attach-metadata-to-a-step.md | 16 ++++++------- .../attach-metadata-to-an-artifact.md | 5 ++++ 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index eefb799547b..386b3b60b8a 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -94,4 +94,9 @@ model = client.get_model_version("my_model", "my_version") print(model.run_metadata["metadata_key"]) ``` +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} +
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md index 5d1495b79c6..ca2aa34844a 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -14,8 +14,9 @@ custom types like `Uri`, `Path`, `DType`, and `StorageSize`. If you are logging metadata from within a step that’s part of a pipeline run, calling `log_metadata` will attach the specified metadata to the current -pipeline run. This is especially useful for logging details about the run -while it's still active. +pipeline run where the metadata key will have the `step_name:metadata_key` +pattern. This allows you to use the same metadata key from different steps +while the run's still executing. ```python from typing import Annotated @@ -49,15 +50,9 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ return classifier ``` -{% hint style="warning" %} -In order to log metadata to a pipeline run during the step execution without -specifying any additional identifiers, `log_related_entities` should be -`True` (default behavior). -{% endhint %} - -## Logging Metadata Outside a Run +## Manually Logging Metadata to a Pipeline Run -You can also attach metadata to a specific pipeline run after its execution, +You can also attach metadata to a specific pipeline run without needing a step, using identifiers like the run ID. This is useful when logging information or metrics that were calculated post-execution. @@ -82,4 +77,11 @@ client = Client() run = client.get_pipeline_run("run_id_name_or_prefix") print(run.run_metadata["metadata_key"]) -``` \ No newline at end of file +``` + +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} + +
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md index b7fd687a412..224547e966e 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md @@ -49,14 +49,7 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ return classifier ``` -{% hint style="info" %} -If you do not want to log the same metadata for the related entries such as -the pipeline run and the model version, you can set the `log_related_entities` -to `False` when you call `log_metadata`. -{% endhint %} - - -## Logging Metadata Outside a Step +## Manually Logging Metadata a Step Run You can also log metadata for a specific step after execution, using identifiers to specify the pipeline, step, and run. This approach is @@ -92,9 +85,14 @@ the ZenML Client: from zenml.client import Client client = Client() -step = client.get_pipeline_run().steps["step_name"] +step = client.get_pipeline_run("pipeline_id").steps["step_name"] print(step.run_metadata["metadata_key"]) ``` +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} +
ZenML Scarf
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md index fd0e3aa96e3..5f2f962f80b 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md @@ -88,6 +88,11 @@ artifact = client.get_artifact_version("my_artifact", "my_version") print(artifact.run_metadata["metadata_key"]) ``` +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} + ## Grouping Metadata in the Dashboard When logging metadata passing a dictionary of dictionaries in the `metadata` From 4370bdc095d3e0e5273f27bd597ee85ea8507575 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:22:20 +0100 Subject: [PATCH 080/124] docs update 2 --- .../track-metrics-metadata/attach-metadata-to-a-step.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md index 224547e966e..e53b49a8274 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md @@ -49,6 +49,13 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ return classifier ``` +{% hint style="info" %} +If you run a pipeline where the step execution is cached, the cached step run +will copy the metadata that was created in the original step execution. +(If there is any metadata that was generated manually after the execution of +the original step, these entries will not be included in this process.) +{% endhint %} + ## Manually Logging Metadata a Step Run You can also log metadata for a specific step after execution, using From 6b06a901764270dcfab1f6d84577aedf9bca39e0 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:28:35 +0100 Subject: [PATCH 081/124] fixing integration tests --- .../functional/model/test_model_version.py | 12 ++++++++++-- .../functional/zen_stores/test_zen_store.py | 13 ++----------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index 7266a541146..d16b9dc31bd 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -107,10 +107,18 @@ def __exit__(self, exc_type, exc_value, exc_traceback): @step def step_metadata_logging_functional(mdl_name: str): """Functional logging using implicit Model from context.""" - log_metadata({"foo": "bar"}) + model = get_step_context().model + + log_metadata( + metadata={"foo": "bar"}, + model_name=model.name, + model_version=model.version, + ) assert get_step_context().model.run_metadata["foo"] == "bar" log_metadata( - metadata={"foo": "bar"}, model_name=mdl_name, model_version="other" + metadata={"foo": "bar"}, + model_name=mdl_name, + model_version="other", ) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 13e34cb1c15..d05cf04a577 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2919,7 +2919,8 @@ def pipeline_to_log_metadata(metadata): step_to_log_metadata(metadata) -def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): +def \ + test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): store = clean_client.zen_store metadata_values = [3, 25, 100, "random_string", True] @@ -2955,16 +2956,6 @@ def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): with pytest.raises(ValidationError): PipelineRunFilter(name="oneof:random_value") - # Test metadata filtering - runs_filter = PipelineRunFilter(run_metadata={"blupus": "lt:30"}) - runs = store.list_runs(runs_filter_model=runs_filter) - assert len(runs) == 2 # The run with 3 and 25 - - for r in runs: - assert "blupus" in r.run_metadata - assert isinstance(r.run_metadata["blupus"], int) - assert r.run_metadata["blupus"] < 30 - # .--------------------. # | Pipeline run steps | From 59d9867b645fa153f9ebb0e76fe20f9506301b62 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:30:16 +0100 Subject: [PATCH 082/124] spellcheck --- src/zenml/models/v2/misc/run_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/models/v2/misc/run_metadata.py b/src/zenml/models/v2/misc/run_metadata.py index 3d5bb75ec69..1769ff30ad6 100644 --- a/src/zenml/models/v2/misc/run_metadata.py +++ b/src/zenml/models/v2/misc/run_metadata.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Utility classes for modelling run metadata.""" +"""Utility classes for modeling run metadata.""" from datetime import datetime from uuid import UUID From e482c61bb669eea8fd43335af49bb64b84d8c658 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 26 Nov 2024 02:35:38 +0100 Subject: [PATCH 083/124] formatting --- tests/integration/functional/zen_stores/test_zen_store.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index d05cf04a577..7149a6ed71e 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2919,8 +2919,7 @@ def pipeline_to_log_metadata(metadata): step_to_log_metadata(metadata) -def \ - test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): +def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): store = clean_client.zen_store metadata_values = [3, 25, 100, "random_string", True] From ffc578746db03074d57b008f01c731b6553c30cc Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Tue, 26 Nov 2024 01:44:52 +0000 Subject: [PATCH 084/124] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 0a2f40d5a92..38c2abb88a0 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20 +_commit: 2024.11.20-1-gd8d1576 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io From dd69dd7cf765804b6f95a40c6f18f7df090f26da Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 27 Nov 2024 15:35:43 +0100 Subject: [PATCH 085/124] docs changes --- .../track-metrics-metadata/attach-metadata-to-a-model.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index 386b3b60b8a..a1c22cb8872 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -49,8 +49,7 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ "recall": recall } }, - model_name=step_context.model.name, - model_version=step_context.model.version, + model_version_id=step_context.model.id, ) return classifier From eab5ba52ede651e2b860cc0e4c73f563b13f7fdb Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 27 Nov 2024 16:04:37 +0100 Subject: [PATCH 086/124] review comments --- .github/workflows/update-templates-to-examples.yml | 1 - src/zenml/cli/base.py | 1 - src/zenml/utils/metadata_utils.py | 6 +----- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index 9bdb5bfbbb6..f58b2a9424f 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -46,7 +46,6 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - # TODO: Update to a newer date ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 42c2006ddbf..1d5adecae0b 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -79,7 +79,6 @@ def copier_github_url(self) -> str: ZENML_PROJECT_TEMPLATES = dict( e2e_batch=ZenMLProjectTemplateLocation( github_url="zenml-io/template-e2e-batch", - # TODO: Update to a newer date github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), starter=ZenMLProjectTemplateLocation( diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index aa5d3bf4286..49991fda47b 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -107,8 +107,6 @@ def log_metadata( model_version_id: Optional[UUID] = None, model_name: Optional[str] = None, model_version: Optional[str] = None, - # Parameter to adjust whether we log to all related entities - log_related_entities: Optional[bool] = True, ) -> None: """Logs metadata for various resource types in a generalized way. @@ -122,9 +120,7 @@ def log_metadata( artifact_version: The version of the artifact. model_version_id: The ID of the model version. model_name: The name of the model. - model_version: The version of the model - log_related_entities: Flag to decide whether we should log the same - metadata for related entities. + model_version: The version of the model. Raises: ValueError: If no identifiers are provided and the function is not From 81e59205290d989f3cc15449132816c91fc994f6 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 27 Nov 2024 16:29:17 +0100 Subject: [PATCH 087/124] added the batch rbac call --- .../routers/workspaces_endpoints.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 0e055cc81dc..b34bff32ad4 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Endpoint definitions for workspaces.""" -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -105,6 +105,7 @@ ) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + batch_verify_permissions_for_models, get_allowed_resource_ids, verify_permission, verify_permission_for_model, @@ -1009,22 +1010,24 @@ def create_run_metadata( "is not supported." ) + verify_models: List[Any] = [] for resource in run_metadata.resources: if resource.type == MetadataResourceTypes.PIPELINE_RUN: - run = zen_store().get_run(resource.id) - verify_permission_for_model(run, action=Action.UPDATE) + verify_models.append(zen_store().get_run(resource.id)) elif resource.type == MetadataResourceTypes.STEP_RUN: - step = zen_store().get_run_step(resource.id) - verify_permission_for_model(step, action=Action.UPDATE) + verify_models.append(zen_store().get_run_step(resource.id)) elif resource.type == MetadataResourceTypes.ARTIFACT_VERSION: - artifact_version = zen_store().get_artifact_version(resource.id) - verify_permission_for_model(artifact_version, action=Action.UPDATE) + verify_models.append(zen_store().get_artifact_version(resource.id)) elif resource.type == MetadataResourceTypes.MODEL_VERSION: - model_version = zen_store().get_model_version(resource.id) - verify_permission_for_model(model_version, action=Action.UPDATE) + verify_models.append(zen_store().get_model_version(resource.id)) else: raise RuntimeError(f"Unknown resource type: {resource.type}") + batch_verify_permissions_for_models( + models=verify_models, + action=Action.UPDATE, + ) + verify_permission( resource_type=ResourceType.RUN_METADATA, action=Action.CREATE ) From 0f785ff0a0c0bbd823cc590ce5b11a15fed492af Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 27 Nov 2024 16:52:21 +0100 Subject: [PATCH 088/124] added a validator to check the name of the keys --- src/zenml/models/v2/core/run_metadata.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index a53f1cba70e..9579f202def 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -16,7 +16,7 @@ from typing import Dict, List, Optional from uuid import UUID -from pydantic import Field +from pydantic import Field, model_validator from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( @@ -48,3 +48,19 @@ class RunMetadataRequest(WorkspaceScopedRequest): "a step execution.", default=False, ) + + @model_validator(mode="after") + def validate_values_keys(self) -> "RunMetadataRequest": + """Validates if the keys in the metadata are properly defined. + + Returns: + self + """ + invalid_keys = [key for key in self.values.keys() if ":" in key] + if invalid_keys: + raise ValueError( + "You can not use colons (`:`) in the key names when you " + "are creating metadata for your ZenML objects. Please change " + f"the following keys: {invalid_keys}" + ) + return self From 0988141524ade9452ddfd7d8c2301988de7e0d51 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 27 Nov 2024 17:44:42 +0100 Subject: [PATCH 089/124] small adjustments --- tests/integration/functional/zen_stores/test_zen_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index efefc5ef861..3c87f39e190 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2910,7 +2910,7 @@ def test_deleting_run_deletes_steps(): @step def step_to_log_metadata(metadata: Union[str, int, bool]) -> int: - log_metadata({"blupus": metadata}) + log_metadata(metadata={"blupus": metadata}) return 42 From cc297f239a3d209725989c056552268c02cb39fb Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 10:45:15 +0100 Subject: [PATCH 090/124] base schema added --- .../zen_stores/schemas/artifact_schemas.py | 48 +++--------------- src/zenml/zen_stores/schemas/model_schemas.py | 50 ++++--------------- .../schemas/pipeline_run_schemas.py | 29 ++--------- .../zen_stores/schemas/step_run_schemas.py | 44 ++-------------- src/zenml/zen_stores/schemas/utils.py | 49 +++++++++++++++++- 5 files changed, 71 insertions(+), 149 deletions(-) diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 58b5e358b78..27c69b680b6 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -13,9 +13,8 @@ # permissions and limitations under the License. """SQLModel implementation of artifact table.""" -import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID from pydantic import ValidationError @@ -30,7 +29,6 @@ MetadataResourceTypes, TaggableResourceTypes, ) -from zenml.metadata.metadata_types import MetadataType from zenml.models import ( ArtifactResponse, ArtifactResponseBody, @@ -41,10 +39,12 @@ ArtifactVersionResponseBody, ArtifactVersionResponseMetadata, ArtifactVersionUpdate, - RunMetadataEntry, ) from zenml.models.v2.core.artifact import ArtifactRequest -from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema +from zenml.zen_stores.schemas.base_schemas import ( + BaseSchema, + NamedSchema, +) from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.step_run_schemas import ( @@ -52,6 +52,7 @@ StepRunOutputArtifactSchema, ) from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -175,7 +176,7 @@ def update(self, artifact_update: ArtifactUpdate) -> "ArtifactSchema": return self -class ArtifactVersionSchema(BaseSchema, table=True): +class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): """SQL Model for artifact versions.""" __tablename__ = "artifact_version" @@ -310,41 +311,6 @@ def from_request( save_type=artifact_version_request.save_type.value, ) - def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: - """Fetches all the metadata entries related to the artifact version. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the list of entries with this key. - """ - metadata_collection: Dict[str, List[RunMetadataEntry]] = {} - - # Fetch the metadata related to this step - for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_collection: - metadata_collection[rm.run_metadata.key] = [] - metadata_collection[rm.run_metadata.key].append( - RunMetadataEntry( - value=json.loads(rm.run_metadata.value), - created=rm.run_metadata.created, - ) - ) - - return metadata_collection - - def fetch_metadata(self) -> Dict[str, MetadataType]: - """Fetches the latest metadata entry related to the artifact version. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the latest entry with this key. - """ - metadata_collection = self.fetch_metadata_collection() - return { - k: sorted(v, key=lambda x: x.created, reverse=True)[0].value - for k, v in metadata_collection.items() - } - def to_model( self, include_metadata: bool = False, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index ce082e6f76a..2fc454bd020 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """SQLModel implementation of model tables.""" -import json from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import UUID @@ -27,7 +26,6 @@ MetadataResourceTypes, TaggableResourceTypes, ) -from zenml.metadata.metadata_types import MetadataType from zenml.models import ( BaseResponseMetadata, ModelRequest, @@ -47,10 +45,12 @@ ModelVersionResponseMetadata, ModelVersionResponseResources, Page, - RunMetadataEntry, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema -from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema +from zenml.zen_stores.schemas.base_schemas import ( + BaseSchema, + NamedSchema, +) from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.run_metadata_schemas import ( @@ -59,7 +59,10 @@ from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema from zenml.zen_stores.schemas.user_schemas import UserSchema -from zenml.zen_stores.schemas.utils import get_page_from_list +from zenml.zen_stores.schemas.utils import ( + RunMetadataInterface, + get_page_from_list, +) from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -223,7 +226,7 @@ def update( return self -class ModelVersionSchema(NamedSchema, table=True): +class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for model version.""" __tablename__ = MODEL_VERSION_TABLENAME @@ -348,41 +351,6 @@ def from_request( stage=model_version_request.stage, ) - def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: - """Fetches all the metadata entries related to the model version. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the list of entries with this key. - """ - metadata_collection: Dict[str, List[RunMetadataEntry]] = {} - - # Fetch the metadata related to this step - for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_collection: - metadata_collection[rm.run_metadata.key] = [] - metadata_collection[rm.run_metadata.key].append( - RunMetadataEntry( - value=json.loads(rm.run_metadata.value), - created=rm.run_metadata.created, - ) - ) - - return metadata_collection - - def fetch_metadata(self) -> Dict[str, MetadataType]: - """Fetches the latest metadata entry related to the model version. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the latest entry with this key. - """ - metadata_collection = self.fetch_metadata_collection() - return { - k: sorted(v, key=lambda x: x.created, reverse=True)[0].value - for k, v in metadata_collection.items() - } - def to_model( self, include_metadata: bool = False, diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 02cae95044c..fa6954aa369 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -28,7 +28,6 @@ MetadataResourceTypes, TaggableResourceTypes, ) -from zenml.metadata.metadata_types import MetadataType from zenml.models import ( PipelineRunRequest, PipelineRunResponse, @@ -50,6 +49,7 @@ from zenml.zen_stores.schemas.stack_schemas import StackSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -66,7 +66,7 @@ from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema -class PipelineRunSchema(NamedSchema, table=True): +class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for pipeline runs.""" __tablename__ = "pipeline_run" @@ -260,18 +260,8 @@ def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: a dictionary, where the key is the key of the metadata entry and the values represent the list of entries with this key. """ - metadata_collection: Dict[str, List[RunMetadataEntry]] = {} - # Fetch the metadata related to this run - for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_collection: - metadata_collection[rm.run_metadata.key] = [] - metadata_collection[rm.run_metadata.key].append( - RunMetadataEntry( - value=json.loads(rm.run_metadata.value), - created=rm.run_metadata.created, - ) - ) + metadata_collection = super().fetch_metadata_collection() # Fetch the metadata related to the steps of this run for s in self.step_runs: @@ -281,19 +271,6 @@ def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: return metadata_collection - def fetch_metadata(self) -> Dict[str, MetadataType]: - """Fetches the latest metadata entry related to the pipeline run. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the latest entry with this key. - """ - metadata_collection = self.fetch_metadata_collection() - return { - k: sorted(v, key=lambda x: x.created, reverse=True)[0].value - for k, v in metadata_collection.items() - } - def to_model( self, include_metadata: bool = False, diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index a00e8cdfe94..bc3077aea5e 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -30,9 +30,7 @@ MetadataResourceTypes, StepRunInputArtifactType, ) -from zenml.metadata.metadata_types import MetadataType from zenml.models import ( - RunMetadataEntry, StepRunRequest, StepRunResponse, StepRunResponseBody, @@ -44,7 +42,9 @@ StepRunInputResponse, StepRunResponseResources, ) -from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.base_schemas import ( + NamedSchema, +) from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_deployment_schemas import ( PipelineDeploymentSchema, @@ -52,6 +52,7 @@ from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -63,7 +64,7 @@ ) -class StepRunSchema(NamedSchema, table=True): +class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for steps of pipeline runs.""" __tablename__ = "step_run" @@ -203,41 +204,6 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": model_version_id=request.model_version_id, ) - def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: - """Fetches all the metadata entries related to the step run. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the list of entries with this key. - """ - metadata_collection: Dict[str, List[RunMetadataEntry]] = {} - - # Fetch the metadata related to this step - for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_collection: - metadata_collection[rm.run_metadata.key] = [] - metadata_collection[rm.run_metadata.key].append( - RunMetadataEntry( - value=json.loads(rm.run_metadata.value), - created=rm.run_metadata.created, - ) - ) - - return metadata_collection - - def fetch_metadata(self) -> Dict[str, MetadataType]: - """Fetches the latest metadata entry related to the step run. - - Returns: - a dictionary, where the key is the key of the metadata entry - and the values represent the latest entry with this key. - """ - metadata_collection = self.fetch_metadata_collection() - return { - k: sorted(v, key=lambda x: x.created, reverse=True)[0].value - for k, v in metadata_collection.items() - } - def to_model( self, include_metadata: bool = False, diff --git a/src/zenml/zen_stores/schemas/utils.py b/src/zenml/zen_stores/schemas/utils.py index 3c520d1eb86..5484a6a9cc8 100644 --- a/src/zenml/zen_stores/schemas/utils.py +++ b/src/zenml/zen_stores/schemas/utils.py @@ -13,10 +13,14 @@ # permissions and limitations under the License. """Utils for schemas.""" +import json import math -from typing import List, Type, TypeVar +from typing import Dict, List, Type, TypeVar -from zenml.models import BaseResponse, Page +from sqlmodel import Relationship + +from zenml.metadata.metadata_types import MetadataType +from zenml.models import BaseResponse, Page, RunMetadataEntry from zenml.zen_stores.schemas.base_schemas import BaseSchema S = TypeVar("S", bound=BaseSchema) @@ -66,3 +70,44 @@ def get_page_from_list( total=total, items=page_items, ) + + +class RunMetadataInterface: + """The interface for entities with run metadata.""" + + run_metadata_resources = Relationship() + + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. + """ + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} + + # Fetch the metadata related to this step + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + + return metadata_collection + + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } From 4512ac79a09bf7858a71175b87a24cf20adb1cc4 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 10:58:59 +0100 Subject: [PATCH 091/124] formatting --- src/zenml/zen_stores/schemas/artifact_schemas.py | 5 +---- src/zenml/zen_stores/schemas/model_schemas.py | 5 +---- src/zenml/zen_stores/schemas/step_run_schemas.py | 4 +--- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index f2455dff2cf..15d448c92bb 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -41,10 +41,7 @@ ArtifactVersionUpdate, ) from zenml.models.v2.core.artifact import ArtifactRequest -from zenml.zen_stores.schemas.base_schemas import ( - BaseSchema, - NamedSchema, -) +from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.step_run_schemas import ( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 2fc454bd020..7e67c1cf2b1 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -47,10 +47,7 @@ Page, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema -from zenml.zen_stores.schemas.base_schemas import ( - BaseSchema, - NamedSchema, -) +from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.run_metadata_schemas import ( diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 88045c34e36..d7f13745312 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -43,9 +43,7 @@ StepRunInputResponse, StepRunResponseResources, ) -from zenml.zen_stores.schemas.base_schemas import ( - NamedSchema, -) +from zenml.zen_stores.schemas.base_schemas import NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_deployment_schemas import ( PipelineDeploymentSchema, From ff695c47307bd64d92e7aefefe2b2d9457558983 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:04:24 +0100 Subject: [PATCH 092/124] new functionalities --- src/zenml/artifacts/utils.py | 74 ++++++----- src/zenml/model/utils.py | 29 ++--- src/zenml/utils/metadata_utils.py | 204 ++++++++++++++++++------------ 3 files changed, 173 insertions(+), 134 deletions(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 34ce02849fa..f19c27387e5 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -14,6 +14,7 @@ """Utility functions for handling artifacts.""" import base64 +import contextlib import os import tempfile import zipfile @@ -30,6 +31,7 @@ ) from uuid import UUID, uuid4 +from zenml import log_metadata from zenml.artifacts.preexisting_data_materializer import ( PreexistingDataMaterializer, ) @@ -41,7 +43,6 @@ ArtifactSaveType, ArtifactType, ExecutionStatus, - MetadataResourceTypes, StackComponentType, VisualizationType, ) @@ -58,7 +59,6 @@ ArtifactVisualizationRequest, LoadedVisualization, PipelineRunResponse, - RunMetadataResource, StepRunResponse, StepRunUpdate, ) @@ -405,53 +405,61 @@ def log_artifact_metadata( artifact_version: The version of the artifact to log metadata for. If not provided, when being called inside a step that produces an artifact named `artifact_name`, the metadata will be associated to - the corresponding newly created artifact. Or, if not provided when - being called outside a step, or in a step that does not produce - any artifact named `artifact_name`, the metadata will be associated - to the latest version of that artifact. + the corresponding newly created artifact. Raises: ValueError: If no artifact name is provided and the function is not called inside a step with a single output, or, if neither an artifact nor an output with the given name exists. + """ logger.warning( "The `log_artifact_metadata` function is deprecated and will soon be " "removed. Please use `log_metadata` instead." ) - try: + + if artifact_name and artifact_version: + assert artifact_name is not None + + log_metadata( + metadata=metadata, + artifact_name=artifact_name, + artifact_version=artifact_version, + ) + + step_context = None + with contextlib.suppress(RuntimeError): step_context = get_step_context() - in_step_outputs = (artifact_name in step_context._outputs) or ( - not artifact_name and len(step_context._outputs) == 1 + + if step_context and artifact_name in step_context._outputs.keys(): + log_metadata( + metadata=metadata, + artifact_name=artifact_name, + infer_artifact=True, ) - except RuntimeError: - step_context = None - in_step_outputs = False - - if not step_context or not in_step_outputs or artifact_version: - if not artifact_name: - raise ValueError( - "Artifact name must be provided unless the function is called " - "inside a step with a single output." - ) + elif artifact_name: client = Client() - response = client.get_artifact_version(artifact_name, artifact_version) - client.create_run_metadata( + logger.warning( + "Deprecation warning! Currently, you are calling " + "`log_artifact_metadata` from a context, where we use the " + "`artifact_name` to fetch it and link the metadata to its " + "latest version. This behaviour is deprecated and will be " + "removed in the future. To circumvent this, please check" + "the `log_metadata` function." + ) + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=artifact_version + ) + log_metadata( metadata=metadata, - resources=[ - RunMetadataResource( - id=response.id, type=MetadataResourceTypes.ARTIFACT_VERSION - ) - ], + artifact_version_id=artifact_version_model.id, ) - else: - try: - step_context.add_output_metadata( - metadata=metadata, output_name=artifact_name - ) - except StepContextError as e: - raise ValueError(e) + raise ValueError( + "You need to call `log_artifact_metadata` either within a step " + "(potentially with an artifact name) or outside of a step with an " + "artifact name (and/or version)." + ) # ----------------- diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 6f83bd2bd60..b57156435e0 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -16,6 +16,7 @@ from typing import Dict, Optional, Union from uuid import UUID +from zenml import log_metadata from zenml.client import Client from zenml.enums import ModelStages from zenml.exceptions import StepContextError @@ -50,10 +51,6 @@ def log_model_metadata( model_version: The version of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. - - Raises: - ValueError: If no model name/version is provided and the function is not - called inside a step with configured `model` in decorator. """ logger.warning( "The `log_model_metadata` function is deprecated and will soon be " @@ -61,20 +58,16 @@ def log_model_metadata( ) if model_name and model_version: - from zenml import Model - - mv = Model(name=model_name, version=model_version) + log_metadata( + metadata=metadata, + model_version=model_version, + model_name=model_name, + ) else: - try: - step_context = get_step_context() - except RuntimeError: - raise ValueError( - "Model name and version must be provided unless the function is " - "called inside a step with configured `model` in decorator." - ) - mv = step_context.model - - mv.log_metadata(metadata) + log_metadata( + metadata=metadata, + infer_model=True, + ) def link_artifact_version_to_model_version( @@ -107,7 +100,7 @@ def link_artifact_to_model( model: The model to link to. Raises: - RuntimeError: If called outside of a step. + RuntimeError: If called outside a step. """ if not model: is_issue = False diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 49991fda47b..6b08a8c19a2 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,12 +13,11 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" -import contextlib from typing import Dict, Optional, Union, overload from uuid import UUID from zenml.client import Client -from zenml.enums import MetadataResourceTypes +from zenml.enums import MetadataResourceTypes, ModelStages from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.models import RunMetadataResource @@ -27,10 +26,17 @@ logger = get_logger(__name__) +@overload +def log_metadata( + metadata: Dict[str, MetadataType], +) -> None: ... + + @overload def log_metadata( *, metadata: Dict[str, MetadataType], + step_id: UUID, ) -> None: ... @@ -38,7 +44,8 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - artifact_version_id: UUID, + step_name: str, + run_id_name_or_prefix: Union[UUID, str], ) -> None: ... @@ -46,8 +53,7 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - artifact_name: str, - artifact_version: Optional[str] = None, + run_id_name_or_prefix: Union[UUID, str], ) -> None: ... @@ -55,7 +61,7 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - model_version_id: UUID, + artifact_version_id: UUID, ) -> None: ... @@ -63,8 +69,8 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - model_name: str, - model_version: str, + artifact_name: str, + artifact_version: Optional[str] = None, ) -> None: ... @@ -72,15 +78,17 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - step_id: UUID, + infer_artifact: bool = False, + artifact_name: Optional[str] = None, ) -> None: ... +# Model Metadata @overload def log_metadata( *, metadata: Dict[str, MetadataType], - run_id_name_or_prefix: Union[UUID, str], + model_version_id: UUID, ) -> None: ... @@ -88,25 +96,35 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - step_name: str, - run_id_name_or_prefix: Union[UUID, str], + model_name: str, + model_version: Union[ModelStages, int, str], ) -> None: ... +@overload def log_metadata( + *, metadata: Dict[str, MetadataType], - # Parameters to manually log metadata for steps and runs + infer_model: bool = False, +) -> None: ... + + +def log_metadata( + metadata: Dict[str, MetadataType], + # Steps and runs step_id: Optional[UUID] = None, step_name: Optional[str] = None, run_id_name_or_prefix: Optional[Union[UUID, str]] = None, - # Parameters to manually log metadata for artifacts + # Artifacts artifact_version_id: Optional[UUID] = None, artifact_name: Optional[str] = None, artifact_version: Optional[str] = None, - # Parameters to manually log metadata for models + infer_artifact: Optional[bool] = None, + # Models model_version_id: Optional[UUID] = None, model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_version: Optional[Union[ModelStages, int, str]] = None, + infer_model: Optional[bool] = None, ) -> None: """Logs metadata for various resource types in a generalized way. @@ -118,9 +136,13 @@ def log_metadata( artifact_version_id: The ID of the artifact version artifact_name: The name of the artifact. artifact_version: The version of the artifact. + infer_artifact: Flag deciding whether the artifact version should be + inferred from the step context. model_version_id: The ID of the model version. model_name: The name of the model. model_version: The version of the model. + infer_model: Flag deciding whether the model version should be + inferred from the step context. Raises: ValueError: If no identifiers are provided and the function is not @@ -128,29 +150,29 @@ def log_metadata( """ client = Client() - # Log metadata to a step by name and run ID - if step_name is not None and run_id_name_or_prefix is not None: - step_model_id = ( - client.get_pipeline_run(name_id_or_prefix=run_id_name_or_prefix) - .steps[step_name] - .id - ) + # Log metadata to a step by ID + if step_id is not None: client.create_run_metadata( metadata=metadata, resources=[ RunMetadataResource( - id=step_model_id, type=MetadataResourceTypes.STEP_RUN + id=step_id, type=MetadataResourceTypes.STEP_RUN ) ], ) - # Log metadata to a step by ID - elif step_id is not None: + # Log metadata to a step by name and run ID + elif step_name is not None and run_id_name_or_prefix is not None: + step_model_id = ( + client.get_pipeline_run(name_id_or_prefix=run_id_name_or_prefix) + .steps[step_name] + .id + ) client.create_run_metadata( metadata=metadata, resources=[ RunMetadataResource( - id=step_id, type=MetadataResourceTypes.STEP_RUN + id=step_model_id, type=MetadataResourceTypes.STEP_RUN ) ], ) @@ -174,15 +196,7 @@ def log_metadata( from zenml import Model mv = Model(name=model_name, version=model_version) - - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=mv.id, type=MetadataResourceTypes.MODEL_VERSION - ) - ], - ) + mv.log_metadata(metadata) # Log metadata to a model version by id elif model_version_id is not None: @@ -196,51 +210,36 @@ def log_metadata( ], ) - # If the user provides an artifact name, there are three possibilities. If - # an artifact version is also provided with the name, we use both to fetch - # the artifact version and use it to log the metadata. If no version is - # provided, if the function is called within a step we search the artifacts - # of the step if not we fetch the latest version and attach the metadata - # to the latest version. - elif artifact_name is not None: - if artifact_version: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_name, version=artifact_version - ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=artifact_version_model.id, - type=MetadataResourceTypes.ARTIFACT_VERSION, - ) - ], + # Log metadata to a model through the step context + elif infer_model is True: + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "If you are using the `infer_model` option, the function must " + "be called inside a step with configured `model` in decorator." + "Otherwise, you can provide a `model_version_id` or a " + "combination of `model_name` and `model_version`." ) - else: - step_context = None - with contextlib.suppress(RuntimeError): - step_context = get_step_context() + mv = step_context.model + mv.log_metadata(metadata) - if step_context and artifact_name in step_context._outputs: - step_context.add_output_metadata( - metadata=metadata, output_name=artifact_name - ) - else: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_name - ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=artifact_version_model.id, - type=MetadataResourceTypes.ARTIFACT_VERSION, - ) - ], + # Log metadata to an artifact version by its name and version + elif artifact_name is not None and artifact_version is not None: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=artifact_version + ) + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=artifact_version_model.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, ) + ], + ) - # If the user directly provides an artifact_version_id, we use the client to - # fetch is and attach the metadata accordingly. + # Log metadata to an artifact version by its ID elif artifact_version_id is not None: client.create_run_metadata( metadata=metadata, @@ -252,6 +251,39 @@ def log_metadata( ], ) + # Log metadata to an artifact version through the step context + elif infer_artifact is True: + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "When you are using the `infer_artifact` option when you call " + "`log_metadata`, it must be called inside a step with outputs." + "Otherwise, you can provide a `artifact_version_id` or a " + "combination of `artifact_name` and `artifact_version`." + ) + + step_output_names = list(step_context._outputs.keys()) + + if artifact_name is not None: + # If a name provided, ensure it is in the outputs + assert artifact_name in step_output_names, ( + f"The provided `artifact_name` does not exist in the " + f"step outputs: {step_output_names}." + ) + else: + # If no name provided, ensure there is only one output + assert len(step_output_names) == 1, ( + "There is mode than one output. If you would like to use the " + "`infer_artifact` option, you need to define an artifact_name." + ) + + artifact_name = step_output_names[0] + + step_context.add_output_metadata( + metadata=metadata, output_name=artifact_name + ) + # If every additional value is None, that means we are calling it bare bones # and this call needs to happen during a step execution. We will use the # step context to fetch the step, run and possibly the model version and @@ -296,22 +328,28 @@ def log_metadata( Unsupported way to call the `log_metadata`. Possible combinations " include: - # Inside a step + # Automatic logging to a step (within a step) log_metadata(metadata={}) - # Manual logging for a step + # Manual logging to a step log_metadata(metadata={}, step_name=..., run_id_name_or_prefix=...) log_metadata(metadata={}, step_id=...) - # Manual logging for a run + # Manual logging to a run log_metadata(metadata={}, run_id_name_or_prefix=...) - # Manual logging for a model + # Automatic logging to a model (within a step) + log_metadata(metadata={}, infer_model=True) + + # Manual logging to a model log_metadata(metadata={}, model_name=..., model_version=...) log_metadata(metadata={}, model_version_id=...) - # Manual logging for an artifact - log_metadata(metadata={}, artifact_name=...) # inside a step + # Automatic logging to an artifact (within a step) + log_metadata(metadata={}, infer_artifact=True) # step with single output + log_metadata(metadata={}, artifact_name=..., infer_artifact=True) # specific output of a step + + # Manual logging to an artifact log_metadata(metadata={}, artifact_name=..., artifact_version=...) log_metadata(metadata={}, artifact_version_id=...) """ From 5e34213d63d39299c428dbcd206d565adbded310 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:09:24 +0100 Subject: [PATCH 093/124] breaking circular imports --- src/zenml/artifacts/utils.py | 3 ++- src/zenml/model/utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index f19c27387e5..73d8fcb13e9 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -31,7 +31,6 @@ ) from uuid import UUID, uuid4 -from zenml import log_metadata from zenml.artifacts.preexisting_data_materializer import ( PreexistingDataMaterializer, ) @@ -418,6 +417,8 @@ def log_artifact_metadata( "removed. Please use `log_metadata` instead." ) + from zenml import log_metadata + if artifact_name and artifact_version: assert artifact_name is not None diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index b57156435e0..93b552a8679 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -16,7 +16,6 @@ from typing import Dict, Optional, Union from uuid import UUID -from zenml import log_metadata from zenml.client import Client from zenml.enums import ModelStages from zenml.exceptions import StepContextError @@ -57,6 +56,8 @@ def log_model_metadata( "removed. Please use `log_metadata` instead." ) + from zenml import log_metadata + if model_name and model_version: log_metadata( metadata=metadata, From be6ea07805358aab50a752c0011723fdbf3bb8bf Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:11:27 +0100 Subject: [PATCH 094/124] spellchecker --- src/zenml/artifacts/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 73d8fcb13e9..138b3f53699 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -444,7 +444,7 @@ def log_artifact_metadata( "Deprecation warning! Currently, you are calling " "`log_artifact_metadata` from a context, where we use the " "`artifact_name` to fetch it and link the metadata to its " - "latest version. This behaviour is deprecated and will be " + "latest version. This behavior is deprecated and will be " "removed in the future. To circumvent this, please check" "the `log_metadata` function." ) From ee68112580333639d4cf7c0291b70f153f18b9e9 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:27:18 +0100 Subject: [PATCH 095/124] other minor fixes --- src/zenml/model/utils.py | 10 +++++++++- src/zenml/models/v2/core/run_metadata.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 93b552a8679..a3612fc2c12 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -50,6 +50,9 @@ def log_model_metadata( model_version: The version of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. + + Raises: + ValueError: If the function is not called with proper input. """ logger.warning( "The `log_model_metadata` function is deprecated and will soon be " @@ -64,11 +67,16 @@ def log_model_metadata( model_version=model_version, model_name=model_name, ) - else: + elif model_name is None and model_version is None: log_metadata( metadata=metadata, infer_model=True, ) + else: + raise ValueError( + "You can call `log_model_metadata` by either providing both " + "`model_name` and `model_version` or keeping both of them None." + ) def link_artifact_version_to_model_version( diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 9579f202def..75d3b00427e 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -55,6 +55,9 @@ def validate_values_keys(self) -> "RunMetadataRequest": Returns: self + + Raises: + ValueError: if one of the key in the metadata contains `:` """ invalid_keys = [key for key in self.values.keys() if ":" in key] if invalid_keys: From 1b8a6f1c82ca091b1871c1271500bac3c953319e Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:53:35 +0100 Subject: [PATCH 096/124] covering the uncovered case --- src/zenml/artifacts/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 138b3f53699..6a264ff9f3e 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -438,6 +438,14 @@ def log_artifact_metadata( artifact_name=artifact_name, infer_artifact=True, ) + elif step_context and len(step_context._outputs) == 1: + single_output_name = list(step_context._outputs.keys())[0] + + log_metadata( + metadata=metadata, + artifact_name=single_output_name, + infer_artifact=True, + ) elif artifact_name: client = Client() logger.warning( From f343acda7676c4f7eff4afdf51b7876b311fd445 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 14:59:30 +0100 Subject: [PATCH 097/124] adjusting tests --- tests/integration/functional/artifacts/test_utils.py | 6 +++++- .../functional/pipelines/test_pipeline_context.py | 2 +- tests/integration/functional/steps/test_step_context.py | 3 ++- tests/integration/functional/test_client.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 5c091e61fb0..02092a33870 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -215,7 +215,11 @@ def artifact_multi_output_metadata_logging_step() -> ( "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_metadata(metadata=output_metadata, artifact_name="int_output") + log_metadata( + metadata=output_metadata, + artifact_name="int_output", + infer_artifact=True, + ) return "42", 42 diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 70e7608f7a8..f070ca0272f 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -109,7 +109,7 @@ def producer() -> Annotated[str, "bar"]: ) log_metadata( metadata={"foobar": "artifact_meta_" + model.version}, - artifact_name="bar", + infer_artifact=True, ) return "artifact_data_" + model.version diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index d520cfd83a4..f34e53b8e4b 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -93,7 +93,8 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: log_metadata( - metadata={"some_key": "some_value"}, artifact_name="my_output" + metadata={"some_key": "some_value"}, + infer_artifact=True, ) return 42 diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index c9f0cb3554a..bb2cc4bc8e5 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -979,7 +979,7 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: from zenml.client import Client log_metadata( - metadata={"some_meta": "meta_new_one"}, artifact_name="new_one" + metadata={"some_meta": "meta_new_one"}, infer_artifact=True, ) client = Client() From a36ab4918504526f3c7421456407662cec0838c4 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 15:00:41 +0100 Subject: [PATCH 098/124] fixing the quickstart again --- examples/quickstart/steps/model_evaluator.py | 11 ++--------- examples/quickstart/steps/model_tester.py | 14 +++++--------- tests/integration/functional/test_client.py | 3 ++- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/examples/quickstart/steps/model_evaluator.py b/examples/quickstart/steps/model_evaluator.py index fc8dac00132..4ae2e979396 100644 --- a/examples/quickstart/steps/model_evaluator.py +++ b/examples/quickstart/steps/model_evaluator.py @@ -20,7 +20,7 @@ T5ForConditionalGeneration, ) -from zenml import get_step_context, log_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -50,11 +50,4 @@ def evaluate_model( avg_loss = total_loss / num_batches print(f"Average loss on the dataset: {avg_loss}") - step_context = get_step_context() - - if step_context.model: - log_metadata( - metadata={"Average Loss": avg_loss}, - model_name=step_context.model.name, - model_version=step_context.model.version, - ) + log_metadata(metadata={"Average Loss": avg_loss}, infer_model=True) diff --git a/examples/quickstart/steps/model_tester.py b/examples/quickstart/steps/model_tester.py index 72d68ed7d57..93e261b7ef4 100644 --- a/examples/quickstart/steps/model_tester.py +++ b/examples/quickstart/steps/model_tester.py @@ -21,7 +21,7 @@ T5TokenizerFast, ) -from zenml import get_step_context, log_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger from .data_loader import PROMPT @@ -70,11 +70,7 @@ def test_model( sentence_without_prompt: decoded_output } - step_context = get_step_context() - - if step_context.model: - log_metadata( - metadata={"Example Prompts": test_collection}, - model_name=step_context.model.name, - model_version=step_context.model.version, - ) + log_metadata( + metadata={"Example Prompts": test_collection}, + infer_model=True, + ) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index bb2cc4bc8e5..9afbb3b9b14 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -979,7 +979,8 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: from zenml.client import Client log_metadata( - metadata={"some_meta": "meta_new_one"}, infer_artifact=True, + metadata={"some_meta": "meta_new_one"}, + infer_artifact=True, ) client = Client() From 77b2310832ab31959697afb2c38a967eb12c8a2d Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 15:17:25 +0100 Subject: [PATCH 099/124] minor change --- src/zenml/artifacts/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 6a264ff9f3e..2573964aa77 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -457,7 +457,7 @@ def log_artifact_metadata( "the `log_metadata` function." ) artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_name, version=artifact_version + name_id_or_prefix=artifact_name ) log_metadata( metadata=metadata, From 94c26fd4cd2e790c6d5b2816bfafd0777b31d006 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 15:51:13 +0100 Subject: [PATCH 100/124] going back to publisher step id --- src/zenml/client.py | 8 ++++---- src/zenml/models/v2/core/run_metadata.py | 7 +++---- src/zenml/utils/metadata_utils.py | 2 +- .../versions/cc269488e5a9_separate_run_metadata.py | 4 ++-- src/zenml/zen_stores/schemas/run_metadata_schemas.py | 10 +++++++++- src/zenml/zen_stores/sql_zen_store.py | 5 +++-- 6 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index cf11f646db5..755161345fe 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4440,7 +4440,7 @@ def create_run_metadata( metadata: Dict[str, "MetadataType"], resources: List[RunMetadataResource], stack_component_id: Optional[UUID] = None, - cached: bool = False, + publisher_step_id: Optional[UUID] = None, ) -> None: """Create run metadata. @@ -4450,8 +4450,8 @@ def create_run_metadata( metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. - cached: A flag indicating if the run metadata can be cached during - a step execution. + publisher_step_id: The ID of the step execution that publishes + this metadata automatically. """ from zenml.metadata.metadata_types import get_metadata_type @@ -4482,7 +4482,7 @@ def create_run_metadata( user=self.active_user.id, resources=resources, stack_component_id=stack_component_id, - cached=cached, + publisher_step_id=publisher_step_id, values=values, types=types, ) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index 75d3b00427e..5822451357d 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -43,10 +43,9 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) - cached: Optional[bool] = Field( - title="A flag indicating if the run metadata is cached through " - "a step execution.", - default=False, + publisher_step_id: Optional[UUID] = Field( + title="The ID of the step execution that published this metadata.", + default=None, ) @model_validator(mode="after") diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 6b08a8c19a2..80b27e6af10 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -319,7 +319,7 @@ def log_metadata( type=MetadataResourceTypes.STEP_RUN, ) ], - cached=True, + publisher_step_id=step_context.step_run.id, ) else: diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 922c0f14b41..ca254308ea6 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -90,7 +90,7 @@ def upgrade() -> None: op.add_column( "run_metadata", sa.Column( - "cached", sa.Boolean(), nullable=True, server_default=sa.false() + "publisher_step_id", sqlmodel.sql.sqltypes.GUID(), nullable=True ), ) @@ -137,4 +137,4 @@ def downgrade() -> None: op.drop_table("run_metadata_resource") # Drop the cached column - op.drop_column("run_metadata", "cached") + op.drop_column("run_metadata", "publisher_step_id") diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 8c61ff5c98a..b8945cf60e6 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -78,7 +78,15 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str - cached: Optional[bool] = Field(default=False) + + publisher_step_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=StepRunSchema.__tablename__, + source_column="publisher_step_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) class RunMetadataResourceSchema(SQLModel, table=True): diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 44da555fd14..5f44873e87b 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5572,7 +5572,7 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: key=key, value=json.dumps(value), type=type_, - cached=run_metadata.cached, + publisher_step_id=run_metadata.publisher_step_id, ) session.add(run_metadata_schema) session.commit() @@ -8215,7 +8215,8 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: == MetadataResourceTypes.STEP_RUN ) .where( - RunMetadataSchema.cached == True # noqa: E712 + RunMetadataSchema.publisher_step_id + == step_run.original_step_run_id ) ).all() From aabb2ba44e8147fd4907744ee273570461241c1c Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 15:52:15 +0100 Subject: [PATCH 101/124] updating github refs --- .github/workflows/update-templates-to-examples.yml | 8 ++++---- src/zenml/cli/base.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index f58b2a9424f..cb6a8ff6a93 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -46,7 +46,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -118,7 +118,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -189,7 +189,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -261,7 +261,7 @@ jobs: with: python-version: ${{ inputs.python-version }} ref-zenml: ${{ github.ref }} - ref-template: 2024.11.13 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 1d5adecae0b..8bc22c45446 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -79,19 +79,19 @@ def copier_github_url(self) -> str: ZENML_PROJECT_TEMPLATES = dict( e2e_batch=ZenMLProjectTemplateLocation( github_url="zenml-io/template-e2e-batch", - github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), starter=ZenMLProjectTemplateLocation( github_url="zenml-io/template-starter", - github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), nlp=ZenMLProjectTemplateLocation( github_url="zenml-io/template-nlp", - github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), llm_finetuning=ZenMLProjectTemplateLocation( github_url="zenml-io/template-llm-finetuning", - github_tag="2024.11.13", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), ) From dc1df139014d9bb907cc226c4260d85def562d04 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Thu, 28 Nov 2024 14:57:03 +0000 Subject: [PATCH 102/124] Auto-update of LLM Finetuning template --- examples/llm_finetuning/.copier-answers.yml | 2 +- .../llm_finetuning/steps/prepare_datasets.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 250f3b832e8..7deecebb1d2 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08-1-gd399790 +_commit: 2024.11.08-2-gece1d46 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 diff --git a/examples/llm_finetuning/steps/prepare_datasets.py b/examples/llm_finetuning/steps/prepare_datasets.py index 00086bcdaf8..b9cc13c2261 100644 --- a/examples/llm_finetuning/steps/prepare_datasets.py +++ b/examples/llm_finetuning/steps/prepare_datasets.py @@ -22,7 +22,7 @@ from typing_extensions import Annotated from utils.tokenizer import generate_and_tokenize_prompt, load_tokenizer -from zenml import get_step_context, log_metadata, step +from zenml import log_metadata, step from zenml.materializers import BuiltInMaterializer from zenml.utils.cuda_utils import cleanup_gpu_memory @@ -49,16 +49,13 @@ def prepare_data( cleanup_gpu_memory(force=True) - context = get_step_context() - if context.model: - log_metadata( - metadata={ - "system_prompt": system_prompt, - "base_model_id": base_model_id, - }, - model_name=context.model.name, - model_version=context.model.version, - ) + log_metadata( + metadata={ + "system_prompt": system_prompt, + "base_model_id": base_model_id, + }, + infer_model=True, + ) tokenizer = load_tokenizer(base_model_id, False, use_fast) gen_and_tokenize = partial( From 0a7e26a23535ed917629b4052905fe18489cf6ea Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Thu, 28 Nov 2024 14:57:21 +0000 Subject: [PATCH 103/124] Auto-update of Starter template --- examples/mlops_starter/.copier-answers.yml | 2 +- examples/mlops_starter/steps/data_preprocessor.py | 1 + examples/mlops_starter/steps/model_evaluator.py | 15 ++++++++++----- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index 1c65d17e37c..364bccaa9d0 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-3-g52bf387 +_commit: 2024.10.30-7-gb60e441 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/examples/mlops_starter/steps/data_preprocessor.py b/examples/mlops_starter/steps/data_preprocessor.py index f20cd93aa13..f94d1e85f6d 100644 --- a/examples/mlops_starter/steps/data_preprocessor.py +++ b/examples/mlops_starter/steps/data_preprocessor.py @@ -90,5 +90,6 @@ def data_preprocessor( log_metadata( metadata={"random_state": random_state, "target": target}, artifact_name="preprocess_pipeline", + infer_artifact=True, ) return dataset_trn, dataset_tst, preprocess_pipeline diff --git a/examples/mlops_starter/steps/model_evaluator.py b/examples/mlops_starter/steps/model_evaluator.py index a771d2fdd76..c63c53109f4 100644 --- a/examples/mlops_starter/steps/model_evaluator.py +++ b/examples/mlops_starter/steps/model_evaluator.py @@ -21,6 +21,7 @@ from sklearn.base import ClassifierMixin from zenml import log_metadata, step +from zenml.client import Client from zenml.logger import get_logger logger = get_logger(__name__) @@ -79,27 +80,31 @@ def model_evaluator( dataset_tst.drop(columns=[target]), dataset_tst[target], ) - logger.info(f"Train accuracy={trn_acc*100:.2f}%") - logger.info(f"Test accuracy={tst_acc*100:.2f}%") + logger.info(f"Train accuracy={trn_acc * 100:.2f}%") + logger.info(f"Test accuracy={tst_acc * 100:.2f}%") messages = [] if trn_acc < min_train_accuracy: messages.append( - f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !" + f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !" ) if tst_acc < min_test_accuracy: messages.append( - f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !" + f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !" ) else: for message in messages: logger.warning(message) + client = Client() + latest_classifier = client.get_artifact_version("sklearn_classifier") + log_metadata( metadata={ "train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc), }, - artifact_name="sklearn_classifier", + artifact_version_id=latest_classifier.id, ) + return float(tst_acc) From d6de5cf5a969406742e7e8b97e60f7ffa3124aba Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 16:01:40 +0100 Subject: [PATCH 104/124] fixing tests --- tests/integration/functional/artifacts/test_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 02092a33870..c7b4a664f54 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -123,15 +123,16 @@ def _load_pipeline(expected_value, name, version): def test_log_metadata_existing(clean_client): """Test logging artifact metadata for existing artifacts.""" - save_artifact(42, "meaning_of_life") + av = save_artifact(42, "meaning_of_life") log_metadata( metadata={"description": "Aria is great!"}, - artifact_name="meaning_of_life", + artifact_version_id=av.id, ) save_artifact(43, "meaning_of_life", version="43") log_metadata( metadata={"description_2": "Blupus is great!"}, artifact_name="meaning_of_life", + artifact_version="43" ) log_metadata( metadata={"description_3": "Axl is great!"}, From 2a5feb26c17f627eb0890e882d240c1d458f9f7f Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 16:11:54 +0100 Subject: [PATCH 105/124] updated docs --- .../attach-metadata-to-a-model.md | 38 ++++++++----------- .../attach-metadata-to-a-run.md | 2 +- .../attach-metadata-to-an-artifact.md | 29 ++++++-------- .../logging-metadata.md | 1 - src/zenml/utils/metadata_utils.py | 4 +- 5 files changed, 30 insertions(+), 44 deletions(-) diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index a1c22cb8872..05bd97d5529 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -37,20 +37,16 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ classifier = RandomForestClassifier().fit(dataset) accuracy, precision, recall = ... - step_context = get_step_context() - - if step_context.model: - # Log metadata for the model - log_metadata( - metadata={ - "evaluation_metrics": { - "accuracy": accuracy, - "precision": precision, - "recall": recall - } - }, - model_version_id=step_context.model.id, - ) + log_metadata( + metadata={ + "evaluation_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + }, + infer_model=True, + ) return classifier ``` @@ -60,21 +56,17 @@ specific classifier artifact. This is particularly useful when the metadata reflects an aggregation or summary of various steps and artifacts in the pipeline. -{% hint style="info" %} -You can use the `get_step_context()` function to get fetch the model and model -version that the step is using. -{% endhint %} ### Selecting Models with `log_metadata` -When using `log_metadata` with a model, ZenML provides flexible options to -attach metadata accurately: +When using `log_metadata`, ZenML provides flexible options of attaching +metadata to model versions: -1. **Model Name and Version Provided**: If both a model name and version are +1. **Using `infer_model`**: If used within a step, ZenML will use the step + context to infer the model it is using and attach the metadata to it. +2. **Model Name and Version Provided**: If both a model name and version are provided, ZenML will use these to identify and attach metadata to the specific model version. -2. **Model Name Only**: If only a model name is provided, ZenML will attach - metadata to the latest version of the model. 3. **Model Version ID Provided**: If a model version ID is directly provided, ZenML will use it to fetch and attach the metadata to that specific model version. diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md index ca2aa34844a..e04a0c9006f 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -14,7 +14,7 @@ custom types like `Uri`, `Path`, `DType`, and `StorageSize`. If you are logging metadata from within a step that’s part of a pipeline run, calling `log_metadata` will attach the specified metadata to the current -pipeline run where the metadata key will have the `step_name:metadata_key` +pipeline run where the metadata key will have the `step_name::metadata_key` pattern. This allows you to use the same metadata key from different steps while the run's still executing. diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md index 5f2f962f80b..7f57ac1c18a 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md @@ -26,8 +26,6 @@ value, including ZenML custom types like `Uri`, `Path`, `DType`, and Here's an example of logging metadata for an artifact: ```python -from typing import Annotated - import pandas as pd from zenml import step, log_metadata @@ -35,21 +33,19 @@ from zenml.metadata.metadata_types import StorageSize @step -def process_data_step(dataframe: pd.DataFrame) -> Annotated[ - pd.DataFrame, "processed_data" -]: +def process_data_step(dataframe: pd.DataFrame) -> pd.DataFrame: """Process a dataframe and log metadata about the result.""" processed_dataframe = ... # Log metadata about the processed dataframe log_metadata( - artifact_name="processed_data", metadata={ "row_count": len(processed_dataframe), "columns": list(processed_dataframe.columns), "storage_size": StorageSize( processed_dataframe.memory_usage().sum()) - } + }, + infer_artifact=True, ) return processed_dataframe ``` @@ -59,17 +55,15 @@ def process_data_step(dataframe: pd.DataFrame) -> Annotated[ When using `log_metadata` with an artifact name, ZenML provides flexible options to attach metadata to the correct artifact: -1. **Name and Version Provided**: If both an artifact name and version are +1. **Using `infer_artifact`**: If used within a step, ZenML will use the step +context to infer the outputs artifacts of the step. If the step has only one +output, this artifact will be selected. However, if you additionally +provide an `artifact_name`, ZenML will search for this name in the output space +of the step (useful for step with multiple outputs). +2. **Name and Version Provided**: If both an artifact name and version are provided, ZenML will use these to identify and attach metadata to the specific artifact version. -2. **Name Only, Within a Step**: If only a name is provided and -`log_metadata` is called within a step, ZenML will try to locate the -corresponding output artifact within the step and attach the metadata to it. If -an output with the provided name does not exist in the step, check scenario 3. -3. **Name Only, Outside a Step**: If only a name is provided and -`log_metadata` is called outside a step, ZenML will attach metadata to the -latest version of the artifact. -4. **Artifact Version ID Provided**: If an artifact version ID is provided +3. **Artifact Version ID Provided**: If an artifact version ID is provided directly, ZenML will use it to fetch and attach the metadata to that specific artifact version. @@ -120,10 +114,11 @@ log_metadata( } }, artifact_name="my_artifact", + artifact_version="version", ) ``` -In the ZenML dashboard, "model_metrics" and "data_details" would appear as +In the ZenML dashboard, `model_metrics` and `data_details` would appear as separate cards, each containing their respective key-value pairs. diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md index 63501056a3e..8ea9fc0f1e1 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md @@ -23,7 +23,6 @@ log_metadata( }, "processed_data_size": StorageSize(2500000) }, - artifact_name="my_artifact", ) ``` diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 80b27e6af10..eb10f6678c8 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -268,8 +268,8 @@ def log_metadata( if artifact_name is not None: # If a name provided, ensure it is in the outputs assert artifact_name in step_output_names, ( - f"The provided `artifact_name` does not exist in the " - f"step outputs: {step_output_names}." + f"The provided artifact name`{artifact_name}` does not exist " + f"in the step outputs: {step_output_names}." ) else: # If no name provided, ensure there is only one output From 94bb86a6ce022a7af7528a7b64e8757db3de27ea Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Thu, 28 Nov 2024 15:21:28 +0000 Subject: [PATCH 106/124] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- examples/e2e/steps/hp_tuning/hp_tuning_single_search.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 38c2abb88a0..e6fb1292beb 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20-1-gd8d1576 +_commit: 2024.11.20-2-g760142f _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py index 7948a011e7b..7b55eebae7a 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py @@ -98,6 +98,7 @@ def hp_tuning_single_search( log_metadata( metadata={"metric": float(score)}, artifact_name="hp_result", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### return cv.best_estimator_ From 4997e3a548ce2f1954a4121df2178616c03c7506 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Thu, 28 Nov 2024 15:24:26 +0000 Subject: [PATCH 107/124] Auto-update of NLP template --- examples/e2e_nlp/.copier-answers.yml | 2 +- examples/e2e_nlp/steps/training/model_trainer.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index 33820b0a2d2..274927e3ce5 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-1-g8d87577 +_commit: 2024.10.30-2-g1ae14e3 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/e2e_nlp/steps/training/model_trainer.py b/examples/e2e_nlp/steps/training/model_trainer.py index 812fe712ee4..0a3de574c09 100644 --- a/examples/e2e_nlp/steps/training/model_trainer.py +++ b/examples/e2e_nlp/steps/training/model_trainer.py @@ -160,6 +160,7 @@ def model_trainer( log_metadata( metadata={"metrics": eval_results}, artifact_name="model", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### From 2953f064bdc73c16b6f433f5efef6e000f414ac6 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 16:34:48 +0100 Subject: [PATCH 108/124] formatting --- tests/integration/functional/artifacts/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index c7b4a664f54..79cb52212a8 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -132,7 +132,7 @@ def test_log_metadata_existing(clean_client): log_metadata( metadata={"description_2": "Blupus is great!"}, artifact_name="meaning_of_life", - artifact_version="43" + artifact_version="43", ) log_metadata( metadata={"description_3": "Axl is great!"}, From 88e247c5af091d297c06fe2c19b607d913b12294 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 10:26:10 +0100 Subject: [PATCH 109/124] review comments --- src/zenml/utils/metadata_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index eb10f6678c8..3785f6e8ea5 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -119,12 +119,12 @@ def log_metadata( artifact_version_id: Optional[UUID] = None, artifact_name: Optional[str] = None, artifact_version: Optional[str] = None, - infer_artifact: Optional[bool] = None, + infer_artifact: bool = False, # Models model_version_id: Optional[UUID] = None, model_name: Optional[str] = None, model_version: Optional[Union[ModelStages, int, str]] = None, - infer_model: Optional[bool] = None, + infer_model: bool = False, ) -> None: """Logs metadata for various resource types in a generalized way. From 755c36f5ef82fdb0ce93aea9759c74b6625c243e Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 11:13:05 +0100 Subject: [PATCH 110/124] adding some tests in --- .../integration/functional/utils/__init__.py | 13 ++ .../functional/utils/test_metadata_utils.py | 184 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 tests/integration/functional/utils/__init__.py create mode 100644 tests/integration/functional/utils/test_metadata_utils.py diff --git a/tests/integration/functional/utils/__init__.py b/tests/integration/functional/utils/__init__.py new file mode 100644 index 00000000000..cd90a82cfc2 --- /dev/null +++ b/tests/integration/functional/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/tests/integration/functional/utils/test_metadata_utils.py b/tests/integration/functional/utils/test_metadata_utils.py new file mode 100644 index 00000000000..263cc658291 --- /dev/null +++ b/tests/integration/functional/utils/test_metadata_utils.py @@ -0,0 +1,184 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + + +from typing import Annotated, Tuple + +import pytest + +from zenml import Model, log_metadata, pipeline, step + + +@step +def step_multiple_calls() -> None: + """Step calls log_metadata twice, latest value should be returned.""" + log_metadata(metadata={"blupus": 1}) + log_metadata(metadata={"blupus": 2}) + + +@step +def step_single_output() -> Annotated[int, "first"]: + """Step that tests the usage of infer_artifact flag.""" + log_metadata(metadata={"aria": 1}, infer_artifact=True) + log_metadata( + metadata={"aria": 2}, infer_artifact=True, artifact_name="first" + ) + return 1 + + +@step +def step_multiple_outputs() -> ( + Tuple[Annotated[int, "second"], Annotated[int, "third"]] +): + """Step that tests infer_artifact flag with multiple outputs.""" + log_metadata( + metadata={"axl": 1}, infer_artifact=True, artifact_name="second" + ) + return 1, 2 + + +@step +def step_pipeline_model() -> None: + """Step that tests the infer_model flag.""" + log_metadata(metadata={"p": 1}, infer_model=True) + + +@step(model=Model(name="model_name", version="89a")) +def step_step_model() -> None: + """Step that tests the infer_model flag with a custom model version.""" + log_metadata(metadata={"s": 1}, infer_model=True) + + +@pipeline(model=Model(name="model_name", version="a89"), enable_cache=True) +def pipeline_to_log_metadata(): + """Pipeline definition to test the metadata utils.""" + step_multiple_calls() + step_single_output() + step_multiple_outputs() + step_pipeline_model() + step_step_model() + + +def test_metadata_utils(clean_client): + """Testing different functionalities of the log_metadata function.""" + # Run the pipeline + first_run = pipeline_to_log_metadata() + first_steps = first_run.steps + + # Check if the metadata was tagged correctly + assert first_run.run_metadata["step_multiple_calls::blupus"] == 2 + assert first_steps["step_multiple_calls"].run_metadata["blupus"] == 2 + assert ( + first_steps["step_single_output"] + .outputs["first"][0] + .run_metadata["aria"] + == 2 + ) + assert ( + first_steps["step_multiple_outputs"] + .outputs["second"][0] + .run_metadata["axl"] + == 1 + ) + + model_version_s = Model(name="model_name", version="89a") + assert model_version_s.run_metadata["s"] == 1 + + model_version_p = Model(name="model_name", version="a89") + assert model_version_p.run_metadata["p"] == 1 + + # Manually tag the run + log_metadata( + metadata={"manual_run": True}, run_id_name_or_prefix=first_run.id + ) + + # Manually tag the step + log_metadata( + metadata={"manual_step_1": True}, + step_id=first_run.steps["step_multiple_calls"].id, + ) + log_metadata( + metadata={"manual_step_2": True}, + step_name="step_multiple_calls", + run_id_name_or_prefix=first_run.id, + ) + + # Manually tag a model + log_metadata( + metadata={"manual_model_1": True}, model_version_id=model_version_p.id + ) + log_metadata( + metadata={"manual_model_2": True}, + model_name=model_version_p.name, + model_version=model_version_p.version, + ) + + # Manually tag an artifact + log_metadata( + metadata={"manual_artifact_1": True}, + artifact_version_id=first_run.steps["step_single_output"].output.id, + ) + log_metadata( + metadata={"manual_artifact_2": True}, + artifact_name=first_run.steps["step_single_output"].output.name, + artifact_version=first_run.steps["step_single_output"].output.version, + ) + + # Manually tag one step to test the caching logic later + log_metadata( + metadata={"blupus": 3}, + step_id=first_run.steps["step_multiple_calls"].id, + ) + + # Fetch the run and steps again + first_run_fetched = clean_client.get_pipeline_run( + name_id_or_prefix=first_run.id + ) + first_steps_fetched = first_run_fetched.steps + + assert first_run_fetched.run_metadata["manual_run"] + assert first_run_fetched.run_metadata["step_multiple_calls::manual_step_1"] + assert first_run_fetched.run_metadata["step_multiple_calls::manual_step_2"] + assert first_steps_fetched["step_multiple_calls"].run_metadata[ + "manual_step_1" + ] + assert first_steps_fetched["step_multiple_calls"].run_metadata[ + "manual_step_2" + ] + assert first_steps_fetched["step_single_output"].output.run_metadata[ + "manual_artifact_1" + ] + assert first_steps_fetched["step_single_output"].output.run_metadata[ + "manual_artifact_2" + ] + + # Fetch the model again + model_version_p_fetched = Model(name="model_name", version="a89") + + assert model_version_p_fetched.run_metadata["manual_model_1"] + assert model_version_p_fetched.run_metadata["manual_model_2"] + + # Run the cached pipeline + second_run = pipeline_to_log_metadata() + assert second_run.steps["step_multiple_calls"].run_metadata["blupus"] == 2 + + # Test some of the invalid usages + with pytest.raises(ValueError): + log_metadata(metadata={"auto_step_1": True}) + + with pytest.raises(ValueError): + log_metadata(metadata={"auto_model_1": True}, infer_model=True) + + with pytest.raises(ValueError): + log_metadata(metadata={"auto_artifact_1": True}, infer_artifact=True) From c1b5a4ccb5b5b63b4eacb406c88638b1bf2a5272 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:03:58 +0100 Subject: [PATCH 111/124] review comments --- .../zen_stores/schemas/pipeline_run_schemas.py | 4 +--- src/zenml/zen_stores/schemas/step_run_schemas.py | 4 +--- tests/integration/functional/utils/__init__.py | 13 ------------- 3 files changed, 2 insertions(+), 19 deletions(-) delete mode 100644 tests/integration/functional/utils/__init__.py diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index e4d66748fdd..091017b6ef2 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -297,8 +297,6 @@ def to_model( else {} ) - run_metadata = self.fetch_metadata() - if self.deployment is not None: deployment = self.deployment.to_model() @@ -375,7 +373,7 @@ def to_model( } metadata = PipelineRunResponseMetadata( workspace=self.workspace.to_model(), - run_metadata=run_metadata, + run_metadata=self.fetch_metadata(), config=config, steps=steps, start_time=self.start_time, diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index d7f13745312..75736c6d54c 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -228,8 +228,6 @@ def to_model( RuntimeError: If the step run schema does not have a deployment_id or a step_configuration. """ - run_metadata = self.fetch_metadata() - input_artifacts = { artifact.name: StepRunInputResponse( input_type=StepRunInputArtifactType(artifact.type), @@ -316,7 +314,7 @@ def to_model( pipeline_run_id=self.pipeline_run_id, original_step_run_id=self.original_step_run_id, parent_step_ids=[p.parent_id for p in self.parents], - run_metadata=run_metadata, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/tests/integration/functional/utils/__init__.py b/tests/integration/functional/utils/__init__.py deleted file mode 100644 index cd90a82cfc2..00000000000 --- a/tests/integration/functional/utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. From 93dccd852e4f2791bde63671a51e267ce1d213ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:04:17 +0100 Subject: [PATCH 112/124] Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index ca254308ea6..e211022b52a 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -86,7 +86,6 @@ def upgrade() -> None: op.drop_column("run_metadata", "resource_id") op.drop_column("run_metadata", "resource_type") - # Add the cached column to the database table op.add_column( "run_metadata", sa.Column( From 6bc0002081c4ebd29983149e92ac47a071348f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:04:25 +0100 Subject: [PATCH 113/124] Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index e211022b52a..8ef4e1b78f7 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -42,7 +42,6 @@ def upgrade() -> None: # Migrate existing data from `run_metadata` to `run_metadata_resource` connection = op.get_bind() - # Fetch existing `run_metadata` data run_metadata_data = connection.execute( sa.text(""" SELECT id, resource_id, resource_type From 852d99da7edc9f3ab1ed0cbfd631382790180f1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:04:39 +0100 Subject: [PATCH 114/124] Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 8ef4e1b78f7..59a4106fa5c 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -20,7 +20,6 @@ def upgrade() -> None: """Creates the 'run_metadata_resource' table and migrates data.""" - # Create the `run_metadata_resource` table op.create_table( "run_metadata_resource", sa.Column( From e14283648254557f36bec6b46c6e19457c095dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:04:49 +0100 Subject: [PATCH 115/124] Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 59a4106fa5c..1616ddd5aa7 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -38,7 +38,6 @@ def upgrade() -> None: ), ) - # Migrate existing data from `run_metadata` to `run_metadata_resource` connection = op.get_bind() run_metadata_data = connection.execute( From ded49db47171cd99fff6a1639309bf680d5ab10e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:04:58 +0100 Subject: [PATCH 116/124] Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 1616ddd5aa7..e0b106ba824 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -79,7 +79,6 @@ def upgrade() -> None: resource_data, ) - # Drop the old `resource_id` and `resource_type` columns from `run_metadata` op.drop_column("run_metadata", "resource_id") op.drop_column("run_metadata", "resource_type") From bedb364395582943548a717a0806eb587843c1ca Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:09:41 +0100 Subject: [PATCH 117/124] changed assert to value error --- src/zenml/utils/metadata_utils.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 3785f6e8ea5..72b3e791c4a 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -267,16 +267,22 @@ def log_metadata( if artifact_name is not None: # If a name provided, ensure it is in the outputs - assert artifact_name in step_output_names, ( - f"The provided artifact name`{artifact_name}` does not exist " - f"in the step outputs: {step_output_names}." - ) + if artifact_name not in step_output_names: + raise ValueError( + f"The provided artifact name`{artifact_name}` does not " + f"exist in the step outputs: {step_output_names}." + ) else: # If no name provided, ensure there is only one output - assert len(step_output_names) == 1, ( - "There is mode than one output. If you would like to use the " - "`infer_artifact` option, you need to define an artifact_name." - ) + if len(step_output_names) > 1: + raise ValueError( + "There is more than one output. If you would like to use " + "the `infer_artifact` option, you need to define an " + "`artifact_name`." + ) + + if len(step_output_names) == 0: + raise ValueError("The step does not have any outputs.") artifact_name = step_output_names[0] From dcc8aa912d716c7100f1fa35638de29979dcf5a7 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:12:07 +0100 Subject: [PATCH 118/124] fixed the alembic head --- .../migrations/versions/cc269488e5a9_separate_run_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index e0b106ba824..52a4cbd8ef2 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -1,7 +1,7 @@ """Separate run metadata into resource link table with new UUIDs. Revision ID: cc269488e5a9 -Revises: ec6307720f92 +Revises: b73bc71f1106 Create Date: 2024-11-12 09:46:46.587478 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision = "cc269488e5a9" -down_revision = "ec6307720f92" +down_revision = "b73bc71f1106" branch_labels = None depends_on = None From 772f6392d5c2863a877f5db0895b97bf5756b057 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:16:02 +0100 Subject: [PATCH 119/124] changed the interaction with the models --- src/zenml/utils/metadata_utils.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 72b3e791c4a..22a756b8c47 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -193,10 +193,19 @@ def log_metadata( # Log metadata to a model version by name and version elif model_name is not None and model_version is not None: - from zenml import Model - - mv = Model(name=model_name, version=model_version) - mv.log_metadata(metadata) + model_version_model = client.get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=model_version, + ) + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=model_version_model.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ], + ) # Log metadata to a model version by id elif model_version_id is not None: @@ -221,8 +230,15 @@ def log_metadata( "Otherwise, you can provide a `model_version_id` or a " "combination of `model_name` and `model_version`." ) - mv = step_context.model - mv.log_metadata(metadata) + client.create_run_metadata( + metadata=metadata, + resources=[ + RunMetadataResource( + id=step_context.model_version.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ], + ) # Log metadata to an artifact version by its name and version elif artifact_name is not None and artifact_version is not None: From d13a9090f012182ec6c7a4591ef0c0ee87d0378e Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:33:06 +0100 Subject: [PATCH 120/124] trimmed down --- src/zenml/utils/metadata_utils.py | 144 +++++++++++++----------------- 1 file changed, 64 insertions(+), 80 deletions(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 22a756b8c47..0362fca4325 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" -from typing import Dict, Optional, Union, overload +from typing import Dict, List, Optional, Union, overload from uuid import UUID from zenml.client import Client @@ -150,16 +150,16 @@ def log_metadata( """ client = Client() + resources: List[RunMetadataResource] = [] + publisher_step_id = None + # Log metadata to a step by ID if step_id is not None: - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=step_id, type=MetadataResourceTypes.STEP_RUN - ) - ], - ) + resources = [ + RunMetadataResource( + id=step_id, type=MetadataResourceTypes.STEP_RUN + ) + ] # Log metadata to a step by name and run ID elif step_name is not None and run_id_name_or_prefix is not None: @@ -168,28 +168,22 @@ def log_metadata( .steps[step_name] .id ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=step_model_id, type=MetadataResourceTypes.STEP_RUN - ) - ], - ) + resources = [ + RunMetadataResource( + id=step_model_id, type=MetadataResourceTypes.STEP_RUN + ) + ] # Log metadata to a run by ID elif run_id_name_or_prefix is not None: run_model = client.get_pipeline_run( name_id_or_prefix=run_id_name_or_prefix ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=run_model.id, type=MetadataResourceTypes.PIPELINE_RUN - ) - ], - ) + resources = [ + RunMetadataResource( + id=run_model.id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ] # Log metadata to a model version by name and version elif model_name is not None and model_version is not None: @@ -197,27 +191,21 @@ def log_metadata( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version, ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=model_version_model.id, - type=MetadataResourceTypes.MODEL_VERSION, - ) - ], - ) + resources = [ + RunMetadataResource( + id=model_version_model.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] # Log metadata to a model version by id elif model_version_id is not None: - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=model_version_id, - type=MetadataResourceTypes.MODEL_VERSION, - ) - ], - ) + resources = [ + RunMetadataResource( + id=model_version_id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] # Log metadata to a model through the step context elif infer_model is True: @@ -230,42 +218,33 @@ def log_metadata( "Otherwise, you can provide a `model_version_id` or a " "combination of `model_name` and `model_version`." ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=step_context.model_version.id, - type=MetadataResourceTypes.MODEL_VERSION, - ) - ], - ) + resources = [ + RunMetadataResource( + id=step_context.model_version.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] # Log metadata to an artifact version by its name and version elif artifact_name is not None and artifact_version is not None: artifact_version_model = client.get_artifact_version( name_id_or_prefix=artifact_name, version=artifact_version ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=artifact_version_model.id, - type=MetadataResourceTypes.ARTIFACT_VERSION, - ) - ], - ) + resources = [ + RunMetadataResource( + id=artifact_version_model.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + ] # Log metadata to an artifact version by its ID elif artifact_version_id is not None: - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=artifact_version_id, - type=MetadataResourceTypes.ARTIFACT_VERSION, - ) - ], - ) + resources = [ + RunMetadataResource( + id=artifact_version_id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + ] # Log metadata to an artifact version through the step context elif infer_artifact is True: @@ -305,6 +284,7 @@ def log_metadata( step_context.add_output_metadata( metadata=metadata, output_name=artifact_name ) + return # If every additional value is None, that means we are calling it bare bones # and this call needs to happen during a step execution. We will use the @@ -333,16 +313,14 @@ def log_metadata( "of the step execution, please provide the required " "identifiers." ) - client.create_run_metadata( - metadata=metadata, - resources=[ - RunMetadataResource( - id=step_context.step_run.id, - type=MetadataResourceTypes.STEP_RUN, - ) - ], - publisher_step_id=step_context.step_run.id, - ) + + resources = [ + RunMetadataResource( + id=step_context.step_run.id, + type=MetadataResourceTypes.STEP_RUN, + ) + ] + publisher_step_id = (step_context.step_run.id,) else: raise ValueError( @@ -376,3 +354,9 @@ def log_metadata( log_metadata(metadata={}, artifact_version_id=...) """ ) + + client.create_run_metadata( + metadata=metadata, + resources=resources, + publisher_step_id=publisher_step_id, + ) From 169b4c6d523fc6150cff87799657da99e06d6613 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:35:13 +0100 Subject: [PATCH 121/124] small bugfix --- src/zenml/utils/metadata_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 0362fca4325..24aea003cea 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -320,7 +320,7 @@ def log_metadata( type=MetadataResourceTypes.STEP_RUN, ) ] - publisher_step_id = (step_context.step_run.id,) + publisher_step_id = step_context.step_run.id else: raise ValueError( From ba131c0bed3a97ff0679d0a6f7d557a6bcd32800 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:37:59 +0100 Subject: [PATCH 122/124] naming recommendations --- src/zenml/zen_stores/schemas/artifact_schemas.py | 2 +- src/zenml/zen_stores/schemas/model_schemas.py | 2 +- .../zen_stores/schemas/pipeline_run_schemas.py | 2 +- .../zen_stores/schemas/run_metadata_schemas.py | 16 ++++++++-------- src/zenml/zen_stores/schemas/step_run_schemas.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 15d448c92bb..02e842a5fb5 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -245,7 +245,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): back_populates="artifact_versions" ) run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="artifact_version", + back_populates="artifact_versions", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", cascade="delete", diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 7e67c1cf2b1..feb4a93dc80 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -304,7 +304,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): stage: str = Field(sa_column=Column(TEXT, nullable=True)) run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="model_version", + back_populates="model_versions", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", cascade="delete", diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index db127b1b1ff..d0af218b629 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -141,7 +141,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="pipeline_run", + back_populates="pipeline_runs", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", cascade="delete", diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index b8945cf60e6..f4465b13e66 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -110,31 +110,31 @@ class RunMetadataResourceSchema(SQLModel, table=True): run_metadata: RunMetadataSchema = Relationship(back_populates="resources") # Relationship to link specific resource types - pipeline_run: List["PipelineRunSchema"] = Relationship( + pipeline_runs: List["PipelineRunSchema"] = Relationship( back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", - overlaps="run_metadata_resources,step_run,artifact_version,model_version", + overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions", ), ) - step_run: List["StepRunSchema"] = Relationship( + step_runs: List["StepRunSchema"] = Relationship( back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", - overlaps="run_metadata_resources,pipeline_run,artifact_version,model_version", + overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions", ), ) - artifact_version: List["ArtifactVersionSchema"] = Relationship( + artifact_versions: List["ArtifactVersionSchema"] = Relationship( back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="run_metadata_resources,pipeline_run,step_run,model_version", + overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions", ), ) - model_version: List["ModelVersionSchema"] = Relationship( + model_versions: List["ModelVersionSchema"] = Relationship( back_populates="run_metadata_resources", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", - overlaps="run_metadata_resources,pipeline_run,step_run,artifact_version", + overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions", ), ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 75736c6d54c..f8788505156 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -144,7 +144,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): back_populates="step_runs" ) run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="step_run", + back_populates="step_runs", sa_relationship_kwargs=dict( primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", cascade="delete", From e05fb982d50d4c761213ce4d71e40759d6fa3989 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 13:41:16 +0100 Subject: [PATCH 123/124] linting --- src/zenml/utils/metadata_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 24aea003cea..2b4e641f039 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -218,6 +218,12 @@ def log_metadata( "Otherwise, you can provide a `model_version_id` or a " "combination of `model_name` and `model_version`." ) + + if step_context.model_version is None: + raise ValueError( + "The step context does not feature any model versions." + ) + resources = [ RunMetadataResource( id=step_context.model_version.id, From d662ac36aeb5ee8c0fd738586377cb3fbe37698a Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 15:18:29 +0100 Subject: [PATCH 124/124] fixing the test --- tests/integration/functional/test_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 9afbb3b9b14..72557777ab3 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -1152,14 +1152,14 @@ def dummy(): artifact_name="preexisting", artifact_version="1.2.3", ) + with pytest.raises(KeyError): + clean_client.get_artifact_version("new_one") + dummy() log_metadata( metadata={"some_meta": "meta_preexisting"}, model_name="aria", model_version="model_version", ) - with pytest.raises(KeyError): - clean_client.get_artifact_version("new_one") - dummy() class TestModel: