Custom dataset that registers Vertex AI Artifacts #177

Open

wants to merge 10 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ didn't alter the remote pipeline execution, and only escaped the local Python pr
with the proper remote pipeline execution handling, and possibly per-task timeout enabled by [the new kfp feature](https://github.com/kubeflow/pipelines/pull/10481).
- Assign pipelines to Vertex AI experiments
- Migrated `pydantic` library to v2
- Custom dataset that creates a Vertex AI artifact

## [0.11.1] - 2024-07-01

6 changes: 6 additions & 0 deletions docs/source/03_getting_started/01_quickstart.md
@@ -177,3 +177,9 @@ As you can see, the pipeline was compiled and started in Vertex AI Pipelines. Wh
![Kedro pipeline running in Vertex AI Pipelines](vertexai_running_pipeline.gif)


## Log datasets to Vertex AI Metadata

The plugin implements a custom `kedro_vertexai.vertex_ai.datasets.KedroVertexAIMetadataDataset` dataset that creates a Vertex AI Artifact.
You can specify any Kedro dataset in the `base_dataset` argument; its `_save` and `_load` methods are used for the actual I/O.
Arguments for the base dataset are passed as a dictionary in the `base_dataset_args` argument. The created artifact is associated with the Vertex AI run id and job name as metadata, and additional metadata can be specified in the `metadata` argument.
The `display_name` and `schema` arguments are used when the artifact is created; see the [Vertex AI docs](https://cloud.google.com/vertex-ai/docs/ml-metadata/tracking#create-artifact) to learn more about them. A minimal usage sketch follows below.
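
Below is a minimal sketch, not taken from the plugin's documentation, of how the dataset could be instantiated programmatically; the base dataset class, bucket path, and metadata values are illustrative assumptions, and the same arguments map directly onto a `catalog.yml` entry. Note that `GCP_PROJECT_ID` and `GCP_REGION` must be present in the environment, since the dataset calls `aiplatform.init` with them.

```python
import pandas as pd

from kedro_vertexai.vertex_ai.datasets import KedroVertexAIMetadataDataset

# GCP_PROJECT_ID and GCP_REGION must be set before instantiation; when running
# on Vertex AI the plugin injects them into the container command for you.
dataset = KedroVertexAIMetadataDataset(
    # Any importable Kedro dataset class path; CSVDataset is only an example.
    base_dataset="kedro_datasets.pandas.CSVDataset",
    display_name="companies",
    # Keyword arguments forwarded verbatim to the base dataset's constructor.
    base_dataset_args={"filepath": "gs://my-bucket/companies.csv"},
    # Extra key/value pairs attached to the created artifact.
    metadata={"source": "crm-export"},
    schema="system.Dataset",
)

df = pd.DataFrame({"name": ["Acme"], "country": ["US"]})
dataset.save(df)           # delegates to the base dataset's _save(), then registers the artifact
reloaded = dataset.load()  # delegates to the base dataset's _load()
```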
9 changes: 8 additions & 1 deletion kedro_vertexai/generator.py
@@ -40,7 +40,7 @@ def __init__(self, config, project_name, context, run_name: str):
self.project_name = project_name
self.context: KedroContext = context
self.run_config: RunConfig = config.run_config
self.catalog = context.config_loader.get("catalog*")
self.catalog = context.config_loader.get("catalog")
self.grouping: NodeGrouper = dynamic_init_class(
self.run_config.grouping.cls,
context,
@@ -167,6 +167,7 @@ def _build_kfp_tasks(
"MLFLOW_RUN_ID=\"{{$.inputs.parameters['mlflow_run_id']}}\" "
if is_mlflow_enabled()
else "",
self._generate_gcp_env_vars_command(),
kedro_command,
]
).strip()
@@ -206,6 +207,12 @@ def _generate_params_command(self, should_add_params) -> str:
else ""
)

def _generate_gcp_env_vars_command(self) -> str:
vertex_conf = self.context.config_loader.get("vertexai")
project_id = vertex_conf.get("project_id")
region = vertex_conf.get("region")
return f"GCP_PROJECT_ID={project_id} GCP_REGION={region}"

def _configure_resources(self, name: str, tags: set, task: PipelineTask):
resources = self.run_config.resources_for(name, tags)
node_selectors = self.run_config.node_selectors_for(name, tags)
79 changes: 79 additions & 0 deletions kedro_vertexai/vertex_ai/datasets.py
@@ -1,12 +1,15 @@
import bz2
import os
from functools import lru_cache
from sys import version_info
from typing import Any, Dict

import cloudpickle
import fsspec
from google.cloud import aiplatform as aip
from kedro.io import AbstractDataset, MemoryDataset

from kedro_vertexai.config import dynamic_load_class
from kedro_vertexai.constants import KEDRO_VERTEXAI_BLOB_TEMP_DIR_NAME


@@ -58,3 +61,79 @@ def __getattribute__(self, __name: None) -> Any:
if __name == "__class__":
return MemoryDataset.__getattribute__(MemoryDataset(), __name)
return super().__getattribute__(__name)


class KedroVertexAIMetadataDataset(AbstractDataset):
def __init__(
self,
base_dataset: str,
display_name: str,
base_dataset_args: Dict[str, Any],
metadata: Dict[str, Any],
schema: str = "system.Dataset",
) -> None:
base_dataset_class: AbstractDataset = dynamic_load_class(base_dataset)

self._base_dataset: AbstractDataset = base_dataset_class(**base_dataset_args)
self._display_name = display_name
self._artifact_uri = (
f"{self._base_dataset._protocol}://{self._base_dataset._get_save_path()}"
)
self._artifact_schema = schema

try:
project_id = os.environ["GCP_PROJECT_ID"]
region = os.environ["GCP_REGION"]
except KeyError as e:
self._logger.error(
"""Did you set GCP_PROJECT_ID and GCP_REGION env variables?
They must be set in order to create Vertex AI artifact."""
)
raise e

aip.init(
project=project_id,
location=region,
)

self._run_id = os.environ.get("KEDRO_CONFIG_RUN_ID")
self._job_name = os.environ.get("KEDRO_CONFIG_JOB_NAME")

if self._run_id is None or self._job_name is None:
self._logger.warning(
"""KEDRO_CONFIG_RUN_ID and PIPELINE_JOB_NAME_PLACEHOLDER env variables are not set.
Set them to assign it as artifact metadata."""
)

self._metadata = metadata

super().__init__()

def _load(self) -> Any:
return self._base_dataset._load()

def _save(self, data: Any) -> None:
self._base_dataset._save(data)

self._logger.info(
f"Creating {self._display_name} artifact with uri {self._artifact_uri}"
)

aip.Artifact.create(
schema_title=self._artifact_schema,
display_name=self._display_name,
uri=self._artifact_uri,
metadata={
"pipeline run id": self._run_id,
"pipeline job name": self._job_name,
**self._metadata,
},
)

def _describe(self) -> Dict[str, Any]:
return {
"info": "for use only within Kedro Vertex AI Pipelines",
"display_name": self._display_name,
"artifact_uri": self._artifact_uri,
"base_dataset": self._base_dataset.__class__.__name__,
}
69 changes: 69 additions & 0 deletions tests/test_metadata_dataset.py
@@ -0,0 +1,69 @@
import os
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

from kedro_vertexai.vertex_ai.datasets import KedroVertexAIMetadataDataset


class TestKedroVertexAIMetadataDataset(unittest.TestCase):
def test_dataset(self):
with patch("kedro_vertexai.vertex_ai.datasets.aip.init"), patch(
"kedro_vertexai.vertex_ai.datasets.aip.Artifact.create"
) as aip_artifact_create_mock, patch(
"kedro_vertexai.vertex_ai.datasets.dynamic_load_class"
) as mock_dynamic_load_class:
dataset_class_mock = mock_dynamic_load_class.return_value

mock_dynamic_load_class.return_value.return_value._protocol = "gcs"
mock_dynamic_load_class.return_value.return_value._get_save_path.return_value = Path(
"save_path/file.csv"
)
mock_dynamic_load_class.return_value.return_value.__class__.__name__ = "some_package.SomeDataset"

os.environ["GCP_PROJECT_ID"] = "project id"
os.environ["GCP_REGION"] = "region"

dataset = KedroVertexAIMetadataDataset(
base_dataset="some_package.SomeDataset",
display_name="dataset_name",
base_dataset_args={"some_argument": "its_value"},
metadata={"test_key": "Some additional info"},
)

mock_dynamic_load_class.assert_called_once()
assert len(mock_dynamic_load_class.call_args.args)
assert (
mock_dynamic_load_class.call_args.args[0] == "some_package.SomeDataset"
)

dataset_class_mock.assert_called_once()
assert "some_argument" in dataset_class_mock.call_args.kwargs
assert dataset_class_mock.call_args.kwargs["some_argument"] == "its_value"

assert dataset._artifact_uri == "gcs://save_path/file.csv"

data_mock = MagicMock()
dataset.save(data_mock)

aip_artifact_create_mock.assert_called_once()
assert (
aip_artifact_create_mock.call_args.kwargs["schema_title"]
== "system.Dataset"
)
assert (
aip_artifact_create_mock.call_args.kwargs["display_name"]
== "dataset_name"
)
assert (
aip_artifact_create_mock.call_args.kwargs["uri"]
== "gcs://save_path/file.csv"
)
assert (
aip_artifact_create_mock.call_args.kwargs["metadata"]["test_key"]
== "Some additional info"
)

dataset_info = dataset._describe()
assert dataset_info["display_name"] == "dataset_name"
assert dataset_info["artifact_uri"] == "gcs://save_path/file.csv"