Commit

Fix input resolution for steps with dynamic artifact names (#3228)
* Fix input resolution for steps with dynamic artifact names

* Improve logic

* Linting

* Add test

* Fix variable access

* Really fix test

* Rename
schustmi authored Nov 29, 2024
1 parent ee48d1a commit 0ccb1fd
Showing 6 changed files with 75 additions and 15 deletions.
9 changes: 9 additions & 0 deletions src/zenml/models/v2/core/pipeline_run.py
@@ -550,6 +550,15 @@ def is_templatable(self) -> bool:
"""
return self.get_metadata().is_templatable

@property
def step_substitutions(self) -> Dict[str, Dict[str, str]]:
"""The `step_substitutions` property.
Returns:
the value of the property.
"""
return self.get_metadata().step_substitutions

@property
def model_version(self) -> Optional[ModelVersionResponse]:
"""The `model_version` property.
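For orientation, a minimal usage sketch of the new `step_substitutions` property (the run and step names below are hypothetical): it returns the per-step substitutions directly from the hydrated run metadata, so callers do not need to hydrate every step run individually.

from zenml.client import Client

# Hypothetical lookup of the substitutions recorded for one step of a run.
run = Client().get_pipeline_run("sample_run_name")
trainer_substitutions = run.step_substitutions.get("trainer", {})
# e.g. {"date": "2024_11_29", "time": "10_30_00_000000"}
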
25 changes: 19 additions & 6 deletions src/zenml/orchestrators/input_utils.py
@@ -20,7 +20,7 @@
from zenml.config.step_configurations import Step
from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
from zenml.exceptions import InputResolutionError
from zenml.utils import pagination_utils
from zenml.utils import pagination_utils, string_utils

if TYPE_CHECKING:
from zenml.models import PipelineRunResponse
@@ -53,7 +53,8 @@ def resolve_step_inputs(
current_run_steps = {
run_step.name: run_step
for run_step in pagination_utils.depaginate(
Client().list_run_steps, pipeline_run_id=pipeline_run.id
Client().list_run_steps,
pipeline_run_id=pipeline_run.id,
)
}

@@ -66,11 +67,23 @@
f"No step `{input_.step_name}` found in current run."
)

# Try to get the substitutions from the pipeline run first, as we
# already have a hydrated version of that. In the unlikely case
# that the pipeline run is outdated, we fetch it from the step
# run instead, which will cost us one hydration call.
substitutions = (
pipeline_run.step_substitutions.get(step_run.name)
or step_run.config.substitutions
)
output_name = string_utils.format_name_template(
input_.output_name, substitutions=substitutions
)

try:
outputs = step_run.outputs[input_.output_name]
outputs = step_run.outputs[output_name]
except KeyError:
raise InputResolutionError(
f"No step output `{input_.output_name}` found for step "
f"No step output `{output_name}` found for step "
f"`{input_.step_name}`."
)

@@ -83,12 +96,12 @@
# This should never happen, there can only be a single regular step
# output for a name
raise InputResolutionError(
f"Too many step outputs for output `{input_.output_name}` of "
f"Too many step outputs for output `{output_name}` of "
f"step `{input_.step_name}`."
)
elif len(step_outputs) == 0:
raise InputResolutionError(
f"No step output `{input_.output_name}` found for step "
f"No step output `{output_name}` found for step "
f"`{input_.step_name}`."
)

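The change above formats the configured output name with the producing step's substitutions before the lookup, so inputs that reference dynamically named outputs (names containing placeholders such as `{date}` or `{time}`) resolve to the concrete artifact name. A minimal sketch of that formatting step, assuming an output configured as "output_{date}" and substitution values like the ones used in the test fixtures below:

from zenml.utils import string_utils

# Assumed example values; in the real code the substitutions come from the
# pipeline run or, as a fallback, from the step run's config.
substitutions = {"date": "2024_11_29", "time": "10_30_00_000000"}
output_name = string_utils.format_name_template(
    "output_{date}", substitutions=substitutions
)
# output_name == "output_2024_11_29", the key used for step_run.outputs[output_name]
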
3 changes: 3 additions & 0 deletions src/zenml/orchestrators/step_run_utils.py
@@ -309,6 +309,9 @@ def create_cached_step_runs(
for invocation_id in cache_candidates:
visited_invocations.add(invocation_id)

# Make sure the request factory has the most up-to-date pipeline
# run to avoid hydration calls.
request_factory.pipeline_run = pipeline_run
try:
step_run_request = request_factory.create_request(
invocation_id
@@ -12,6 +12,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from contextlib import ExitStack as does_not_raise
from typing import Callable, Tuple

import pytest
@@ -122,6 +123,11 @@ def mixed_with_unannotated_returns() -> (
)


@step
def step_with_string_input(input_: str) -> None:
pass


@pytest.mark.parametrize(
"step",
[
@@ -362,3 +368,17 @@ def _inner(pass_to_step: str = ""):
assert p2_step_subs["date"] == "step_level"
assert p1_step_subs["funny_name"] == "pipeline_level"
assert p2_step_subs["funny_name"] == "step_level"


def test_dynamically_named_artifacts_in_downstream_steps(
clean_client: "Client",
):
"""Test that dynamically named artifacts can be used in downstream steps."""

@pipeline(enable_cache=False)
def _inner(ret: str):
artifact = dynamic_single_string_standard()
step_with_string_input(artifact)

with does_not_raise():
_inner("output_1")
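The `dynamic_single_string_standard` step is defined elsewhere in this test module. A hypothetical sketch of such a producer, assuming ZenML's name-template support for output artifacts, could look like the following; before this change, the downstream lookup used the raw template name and failed to find the produced artifact.

from typing_extensions import Annotated

from zenml import ArtifactConfig, step


@step
def dynamic_single_string_standard() -> Annotated[
    str, ArtifactConfig(name="output_{date}")
]:
    # The `{date}` placeholder is filled from the step's substitutions at
    # runtime, so the concrete artifact name is only known once the step runs.
    return "some_value"
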
13 changes: 13 additions & 0 deletions tests/unit/conftest.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
@@ -416,6 +417,12 @@ def sample_pipeline_run(
sample_workspace_model: WorkspaceResponse,
) -> PipelineRunResponse:
"""Return sample pipeline run view for testing purposes."""
now = datetime.utcnow()
substitutions = {
"date": now.strftime("%Y_%m_%d"),
"time": now.strftime("%H_%M_%S_%f"),
}

return PipelineRunResponse(
id=uuid4(),
name="sample_run_name",
@@ -430,6 +437,7 @@ def sample_pipeline_run(
workspace=sample_workspace_model,
config=PipelineConfiguration(name="aria_pipeline"),
is_templatable=False,
steps_substitutions=defaultdict(lambda: substitutions.copy()),
),
resources=PipelineRunResponseResources(tags=[]),
)
@@ -543,10 +551,15 @@ def f(
spec = StepSpec.model_validate(
{"source": "module.step_class", "upstream_steps": []}
)
now = datetime.utcnow()
config = StepConfiguration.model_validate(
{
"name": step_name,
"outputs": outputs or {},
"substitutions": {
"date": now.strftime("%Y_%m_%d"),
"time": now.strftime("%H_%M_%S_%f"),
},
}
)
return StepRunResponse(
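For reference, a small illustration of what the assumed default substitutions in these fixtures evaluate to (values are illustrative):

from datetime import datetime

now = datetime(2024, 11, 29, 10, 30, 0)
now.strftime("%Y_%m_%d")     # "2024_11_29"
now.strftime("%H_%M_%S_%f")  # "10_30_00_000000"
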
20 changes: 11 additions & 9 deletions tests/unit/orchestrators/test_input_utils.py
@@ -12,14 +12,13 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from uuid import uuid4

import pytest

from zenml.config.step_configurations import Step
from zenml.enums import StepRunInputArtifactType
from zenml.exceptions import InputResolutionError
from zenml.models import Page, PipelineRunResponse
from zenml.models import Page
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
from zenml.models.v2.core.step_run import StepRunInputResponse
from zenml.orchestrators import input_utils
@@ -29,6 +28,7 @@ def test_input_resolution(
mocker,
sample_artifact_version_model: ArtifactVersionResponse,
create_step_run,
sample_pipeline_run,
):
"""Tests that input resolution works if the correct models exist in the
zen store."""
@@ -60,7 +60,7 @@
)

input_artifacts, parent_ids = input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)
assert input_artifacts == {
"input_name": StepRunInputResponse(
@@ -71,7 +71,7 @@
assert parent_ids == [step_run.id]


def test_input_resolution_with_missing_step_run(mocker):
def test_input_resolution_with_missing_step_run(mocker, sample_pipeline_run):
"""Tests that input resolution fails if the upstream step run is missing."""
mocker.patch(
"zenml.zen_stores.sql_zen_store.SqlZenStore.list_run_steps",
@@ -97,11 +97,13 @@

with pytest.raises(InputResolutionError):
input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)


def test_input_resolution_with_missing_artifact(mocker, create_step_run):
def test_input_resolution_with_missing_artifact(
mocker, create_step_run, sample_pipeline_run
):
"""Tests that input resolution fails if the upstream step run output
artifact is missing."""
step_run = create_step_run(
@@ -132,12 +134,12 @@ def test_input_resolution_with_missing_artifact(mocker, create_step_run):

with pytest.raises(InputResolutionError):
input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)


def test_input_resolution_fetches_all_run_steps(
mocker, sample_artifact_version_model, create_step_run
mocker, sample_artifact_version_model, create_step_run, sample_pipeline_run
):
"""Tests that input resolution fetches all step runs of the pipeline run."""
step_run = create_step_run(
@@ -178,7 +180,7 @@
)

input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)

# `resolve_step_inputs(...)` depaginates the run steps so we fetch all
