Skip to content

Commit

Permalink
more accurate YAML types
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Aug 30, 2023
1 parent 8781cb4 commit 7ea0c07
Show file tree
Hide file tree
Showing 22 changed files with 246 additions and 249 deletions.
19 changes: 7 additions & 12 deletions bioimageio/spec/_internal/base_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from abc import ABC
from collections import UserString
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
Literal,
Mapping,
Optional,
Set,
Expand All @@ -33,12 +31,9 @@
from typing_extensions import Annotated, Self

from bioimageio.spec._internal.constants import IN_PACKAGE_MESSAGE
from bioimageio.spec._internal.field_validation import ValContext, get_validation_context, is_valid_raw_mapping
from bioimageio.spec._internal.field_validation import ValContext, get_validation_context, is_valid_yaml_mapping
from bioimageio.spec._internal.utils import unindent
from bioimageio.spec.types import NonEmpty, RawStringDict, RawValue

if TYPE_CHECKING:
from pydantic.main import IncEx
from bioimageio.spec.types import NonEmpty, YamlMapping, YamlValue

K = TypeVar("K", bound=str)
V = TypeVar("V")
Expand Down Expand Up @@ -170,7 +165,7 @@ def _update_context_and_data(cls, context: ValContext, data: Dict[Any, Any]) ->
cls.convert_from_older_format(data, context)

@classmethod
def convert_from_older_format(cls, data: RawStringDict, context: ValContext) -> None:
def convert_from_older_format(cls, data: YamlMapping, context: ValContext) -> None:
"""A node may `convert` it's raw data from an older format."""
pass

Expand Down Expand Up @@ -304,15 +299,15 @@ def get(self, item: Any, default: D = None) -> Union[V, D]: # type: ignore

@model_validator(mode="after")
def validate_raw_mapping(self) -> Self:
if not is_valid_raw_mapping(self):
raise AssertionError(f"{self} contains values unrepresentable in JSON/YAML")
if not is_valid_yaml_mapping(self):
raise AssertionError(f"{self} contains values unrepresentable in YAML")

return self


class ConfigNode(FrozenDictNode[NonEmpty[str], RawValue]):
class ConfigNode(FrozenDictNode[NonEmpty[str], YamlValue]):
model_config = {**Node.model_config, "extra": "allow"}


class Kwargs(FrozenDictNode[NonEmpty[str], RawValue]):
class Kwargs(FrozenDictNode[NonEmpty[str], YamlValue]):
model_config = {**Node.model_config, "extra": "allow"}
132 changes: 116 additions & 16 deletions bioimageio/spec/_internal/field_validation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,106 @@
from __future__ import annotations

import collections.abc
import dataclasses
import pathlib
import re
from dataclasses import dataclass
from datetime import datetime
from keyword import iskeyword
from pathlib import Path, PurePath
from pathlib import Path, PurePath, PurePosixPath
from typing import TYPE_CHECKING, Any, Dict, Hashable, Mapping, Sequence, Tuple, Type, TypeVar, Union, get_args
from urllib.parse import urljoin

import annotated_types
import packaging.version
from dateutil.parser import isoparse
from pydantic import AnyUrl, DirectoryPath, GetCoreSchemaHandler, functional_validators
from pydantic import AnyUrl, DirectoryPath, FilePath, GetCoreSchemaHandler, ValidationInfo, functional_validators
from pydantic_core import core_schema
from pydantic_core.core_schema import CoreSchema, no_info_after_validator_function
from typing_extensions import NotRequired, TypedDict

from bioimageio.spec._internal.constants import ERROR, SLOTS

if TYPE_CHECKING:
from bioimageio.spec.types import FileSource, RelativePath, WarningLevel
from bioimageio.spec.types import FileSource, WarningLevel


class RelativePath:
path: PurePosixPath

def __init__(self, path: Union[str, Path, RelativePath]) -> None:
super().__init__()
self.path = (
path.path
if isinstance(path, RelativePath)
else PurePosixPath(path.as_posix())
if isinstance(path, Path)
else PurePosixPath(Path(path).as_posix())
)

@property
def __members(self):
return (self.path,)

def __eq__(self, __value: object) -> bool:
return type(__value) is type(self) and self.__members == __value.__members

def __hash__(self) -> int:
return hash(self.__members)

def __str__(self) -> str:
return self.path.as_posix()

def __repr__(self) -> str:
return f"RelativePath('{self.path.as_posix()}')"

@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.general_after_validator_function(
cls._validate,
core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
core_schema.is_instance_schema(pathlib.Path),
core_schema.str_schema(),
]
),
serialization=core_schema.to_string_ser_schema(),
)

def get_absolute(self, root: Union[DirectoryPath, AnyUrl]) -> Union[FilePath, AnyUrl]:
if isinstance(root, pathlib.Path):
return root / self.path
else:
return AnyUrl(urljoin(str(root), str(self.path)))

def _check_exists(self, root: Union[DirectoryPath, AnyUrl]) -> None:
if isinstance((p := self.get_absolute(root)), pathlib.Path) and not p.exists():
raise ValueError(f"{p} does not exist")

@classmethod
def _validate(cls, value: Union[pathlib.Path, str], info: ValidationInfo):
if isinstance(value, str) and (value.startswith("https://") or value.startswith("http://")):
raise ValueError(f"{value} looks like a URL, not a relative path")

ret = cls(value)
root = (info.context or {}).get("root")
if root is not None:
ret._check_exists(root)

return ret


class RelativeFilePath(RelativePath):
def _check_exists(self, root: Union[DirectoryPath, AnyUrl]) -> None:
if isinstance((p := self.get_absolute(root)), pathlib.Path) and not p.is_file():
raise ValueError(f"{p} does not point to an existing file")


class RelativeDirectory(RelativePath):
def _check_exists(self, root: Union[DirectoryPath, AnyUrl]) -> None:
if isinstance((p := self.get_absolute(root)), pathlib.Path) and not p.is_dir():
raise ValueError(f"{p} does not point to an existing directory")


@dataclasses.dataclass(frozen=True, **SLOTS)
Expand Down Expand Up @@ -47,7 +130,7 @@ class WithSuffix:
case_sensitive: bool

def __get_pydantic_core_schema__(self, source: Type[Any], handler: GetCoreSchemaHandler) -> CoreSchema:
from bioimageio.spec.types import FileSource, RelativePath
from bioimageio.spec.types import FileSource

if not self.suffix:
raise ValueError("suffix may not be empty")
Expand Down Expand Up @@ -78,7 +161,7 @@ def validate_datetime(dt: Union[datetime, str, Any]) -> datetime:
elif isinstance(dt, str):
return isoparse(dt)

raise AssertionError(f"'{dt}' not a string or datetime.")
raise ValueError(f"'{dt}' not a string or datetime.")


def validate_identifier(s: str) -> str:
Expand Down Expand Up @@ -110,24 +193,30 @@ def validate_orcid_id(orcid_id: str):
raise ValueError(f"'{orcid_id} is not a valid ORCID iD in hyphenated groups of 4 digits.")


def is_valid_raw_leaf_value(value: Any) -> bool:
from bioimageio.spec.types import RawLeafValue
def is_valid_yaml_leaf_value(value: Any) -> bool:
from bioimageio.spec.types import YamlLeafValue

return isinstance(value, get_args(RawLeafValue))
return isinstance(value, get_args(YamlLeafValue))


def is_valid_raw_mapping(value: Union[Any, Mapping[Any, Any]]) -> bool:
def is_valid_yaml_key(value: Union[Any, Sequence[Any]]) -> bool:
return (
is_valid_yaml_leaf_value(value) or isinstance(value, tuple) and all(is_valid_yaml_leaf_value(v) for v in value)
)


def is_valid_yaml_mapping(value: Union[Any, Mapping[Any, Any]]) -> bool:
return isinstance(value, collections.abc.Mapping) and all(
isinstance(k, str) and is_valid_raw_value(v) for k, v in value.items()
is_valid_yaml_key(k) and is_valid_yaml_value(v) for k, v in value.items()
)


def is_valid_raw_sequence(value: Union[Any, Sequence[Any]]) -> bool:
return isinstance(value, collections.abc.Sequence) and all(is_valid_raw_value(v) for v in value)
def is_valid_yaml_sequence(value: Union[Any, Sequence[Any]]) -> bool:
return isinstance(value, collections.abc.Sequence) and all(is_valid_yaml_value(v) for v in value)


def is_valid_raw_value(value: Any) -> bool:
return any(is_valid(value) for is_valid in (is_valid_raw_leaf_value, is_valid_raw_mapping, is_valid_raw_sequence))
def is_valid_yaml_value(value: Any) -> bool:
return any(is_valid(value) for is_valid in (is_valid_yaml_key, is_valid_yaml_mapping, is_valid_yaml_sequence))


V_suffix = TypeVar("V_suffix", bound=Union[AnyUrl, PurePath, "RelativePath"])
Expand Down Expand Up @@ -175,13 +264,24 @@ def validate_version(v: str) -> str:
return v


class ValidationContext(TypedDict):
root: NotRequired[Union[DirectoryPath, AnyUrl]]
"""url/path serving as base to any relative file paths. Default provided as data field `root`.0"""

file_name: NotRequired[str]
"""The file name of the RDF used only for reporting"""

warning_level: NotRequired[WarningLevel]
"""raise warnings of severity s as validation errors if s >= `warning_level`"""


class ValContext(TypedDict):
"""internally used validation context"""

root: Union[DirectoryPath, AnyUrl]
"""url/path serving as base to any relative file paths. Default provided as data field `root`.0"""

warning_level: "WarningLevel"
warning_level: WarningLevel
"""raise warnings of severity s as validation errors if s >= `warning_level`"""

file_name: str
Expand All @@ -197,7 +297,7 @@ class ValContext(TypedDict):
def get_validation_context(
*,
root: Union[DirectoryPath, AnyUrl] = Path(),
warning_level: "WarningLevel" = ERROR,
warning_level: WarningLevel = ERROR,
file_name: str = "rdf.yaml",
**kwargs: Any,
) -> ValContext:
Expand Down
11 changes: 6 additions & 5 deletions bioimageio/spec/collection/v0_2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import collections.abc
from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Union

from pydantic import ConfigDict, Field, HttpUrl, PrivateAttr, TypeAdapter, field_validator, model_validator
from pydantic import model_validator # type: ignore
from pydantic import ConfigDict, Field, HttpUrl, PrivateAttr, TypeAdapter, field_validator
from pydantic_core import PydanticUndefined
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import Annotated, Self
Expand All @@ -16,7 +17,7 @@
from bioimageio.spec.generic.v0_2 import GenericBase
from bioimageio.spec.model.v0_4 import Model
from bioimageio.spec.notebook.v0_2 import Notebook
from bioimageio.spec.types import NonEmpty, RawStringDict, RawValue, RelativeFilePath
from bioimageio.spec.types import NonEmpty, RelativeFilePath, YamlMapping, YamlValue

__all__ = [
"Attachments",
Expand Down Expand Up @@ -50,7 +51,7 @@ class CollectionEntryBase(Node):
The full collection entry's id is the collection's base id, followed by this sub id and separated by a slash '/'."""

@property
def rdf_update(self) -> Dict[str, RawValue]:
def rdf_update(self) -> Dict[str, YamlValue]:
return self.model_extra or {}

@property
Expand Down Expand Up @@ -172,7 +173,7 @@ def _update_context_and_data(cls, context: ValContext, data: Dict[Any, Any]) ->
context["collection_base_content"] = collection_base_content

@staticmethod
def move_groups_to_collection_field(data: RawStringDict) -> None:
def move_groups_to_collection_field(data: YamlMapping) -> None:
if data.get("format_version") not in ("0.2.0", "0.2.1"):
return

Expand All @@ -196,6 +197,6 @@ def move_groups_to_collection_field(data: RawStringDict) -> None:
data["id"] = id_

@classmethod
def convert_from_older_format(cls, data: RawStringDict, context: ValContext) -> None:
def convert_from_older_format(cls, data: YamlMapping, context: ValContext) -> None:
cls.move_groups_to_collection_field(data)
super().convert_from_older_format(data, context)
4 changes: 2 additions & 2 deletions bioimageio/spec/collection/v0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from bioimageio.spec.model.v0_5 import Model as Model05
from bioimageio.spec.notebook.v0_2 import Notebook as Notebook02
from bioimageio.spec.notebook.v0_3 import Notebook as Notebook03
from bioimageio.spec.types import NonEmpty, RawStringDict
from bioimageio.spec.types import NonEmpty, YamlMapping

__all__ = [
"Attachments",
Expand Down Expand Up @@ -89,6 +89,6 @@ def check_unique_ids(cls, value: NonEmpty[Tuple[CollectionEntry, ...]]) -> NonEm
return value

@classmethod
def convert_from_older_format(cls, data: RawStringDict, context: ValContext) -> None:
def convert_from_older_format(cls, data: YamlMapping, context: ValContext) -> None:
v0_2.Collection.move_groups_to_collection_field(data)
super().convert_from_older_format(data, context)
Loading

0 comments on commit 7ea0c07

Please sign in to comment.