Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dicts/lists visualizable and add JSON as viz type #2882

Merged
merged 10 commits into from
Nov 28, 2024
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class VisualizationType(StrEnum):
HTML = "html"
IMAGE = "image"
MARKDOWN = "markdown"
JSON = "json"


class ZenMLServiceType(StrEnum):
Expand Down
21 changes: 20 additions & 1 deletion src/zenml/materializers/built_in_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)

from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
from zenml.enums import ArtifactType
from zenml.enums import ArtifactType, VisualizationType
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.materializer_registry import materializer_registry
Expand Down Expand Up @@ -414,6 +414,25 @@ def save(self, data: Any) -> None:
for entry in metadata:
self.artifact_store.rmtree(entry["path"])
raise e

# save dict type objects to JSON file with JSON visualization type
def save_visualizations(
    self, data: Any
) -> Dict[str, "VisualizationType"]:
    """Save visualizations for the given data.

    Only a dict that passes `_is_serializable` is exposed as a
    visualization: in that case the file at `self.data_path` is reported
    with the JSON visualization type. Any other input yields no
    visualization.

    Args:
        data: The data to save visualizations for.

    Returns:
        A dictionary of visualization URIs and their types. Empty if the
        data is not a serializable dict.
    """
    # Serializable dicts are saved as a single JSON file, so that file can
    # be reused directly as the visualization. Non-serializable dicts are
    # saved as lists of lists in different files, so there is no single
    # JSON document to point the visualization at.
    if isinstance(data, dict) and _is_serializable(data):
        return {self.data_path: VisualizationType.JSON}
    return {}

def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
"""Extract metadata from the given built-in container object.
Expand Down
11 changes: 8 additions & 3 deletions src/zenml/materializers/structured_string_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,23 @@
from zenml.enums import ArtifactType, VisualizationType
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.types import CSVString, HTMLString, MarkdownString
from zenml.types import CSVString, HTMLString, JSONString, MarkdownString

logger = get_logger(__name__)


STRUCTURED_STRINGS = Union[CSVString, HTMLString, MarkdownString]
STRUCTURED_STRINGS = Union[CSVString, HTMLString, MarkdownString, JSONString]

HTML_FILENAME = "output.html"
MARKDOWN_FILENAME = "output.md"
CSV_FILENAME = "output.csv"
JSON_FILENAME = "output.json"


class StructuredStringMaterializer(BaseMaterializer):
"""Materializer for CSV, HTML, Markdown, or JSON strings."""

ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString)
ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString, JSONString)
ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA_ANALYSIS

def load(self, data_type: Type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS:
Expand Down Expand Up @@ -94,6 +95,8 @@ def _get_filepath(self, data_type: Type[STRUCTURED_STRINGS]) -> str:
filename = HTML_FILENAME
elif issubclass(data_type, MarkdownString):
filename = MARKDOWN_FILENAME
elif issubclass(data_type, JSONString):
filename = JSON_FILENAME
else:
raise ValueError(
f"Data type {data_type} is not supported by this materializer."
Expand All @@ -120,6 +123,8 @@ def _get_visualization_type(
return VisualizationType.HTML
elif issubclass(data_type, MarkdownString):
return VisualizationType.MARKDOWN
elif issubclass(data_type, JSONString):
return VisualizationType.JSON
else:
raise ValueError(
f"Data type {data_type} is not supported by this materializer."
Expand Down
3 changes: 3 additions & 0 deletions src/zenml/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ class MarkdownString(str):

class CSVString(str):
"""Special string class to indicate a CSV string."""

class JSONString(str):
    """Special string class to indicate a JSON string.

    The class adds no behavior on top of `str`; it only acts as a marker
    type so that type-based checks (e.g. `issubclass`) can tell JSON text
    apart from plain strings.
    """
5 changes: 4 additions & 1 deletion src/zenml/utils/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# permissions and limitations under the License.
"""Utility functions for dashboard visualizations."""

import json
from typing import TYPE_CHECKING, Optional

from IPython.core.display import HTML, Image, Markdown, display
from IPython.core.display import HTML, Image, JSON, Markdown, display

from zenml.artifacts.utils import load_artifact_visualization
from zenml.enums import VisualizationType
Expand Down Expand Up @@ -63,6 +64,8 @@ def visualize_artifact(
assert isinstance(visualization.value, str)
table = format_csv_visualization_as_html(visualization.value)
display(HTML(table))
elif visualization.type == VisualizationType.JSON:
display(JSON(json.loads(visualization.value)))
else:
display(visualization.value)

Expand Down
Loading