diff --git a/scripts/interactive_docs/hint.py b/scripts/interactive_docs/hint.py index 359661e0f..3c799e331 100644 --- a/scripts/interactive_docs/hint.py +++ b/scripts/interactive_docs/hint.py @@ -1,14 +1,15 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass -from collections.abc import Mapping, Sequence -from xml.etree import ElementTree as et +import datetime +import inspect import json -import yaml +import types import typing +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass from typing import ( + Any, ClassVar, Dict, - Any, Final, ForwardRef, Literal, @@ -19,18 +20,16 @@ cast, final, ) -from typing_extensions import TypeAliasType, assert_never -from typing_extensions import List, TypeAlias -import datetime from xml.etree import ElementTree as et -import inspect -from annotated_types import Predicate import pydantic +import typing_extensions +import yaml +from annotated_types import Predicate from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined, PydanticUndefinedType -import typing_extensions +from typing_extensions import List, TypeAlias, TypeAliasType, assert_never from bioimageio.spec._internal.io import YamlValue, is_yaml_leaf_value, is_yaml_value @@ -165,13 +164,19 @@ def get_subclasses() -> Sequence[Type["Hint"]]: @final @classmethod def parse( - cls, *, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + *, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": # if raw_hint in cls.hint_cache: # return cls.hint_cache[raw_hint] #FIXME: maybe move this into the individual do_parse impls? hint: "Hint | Unrecognized | Exception" = Unrecognized(raw_hint=raw_hint) for subclass in Hint.get_subclasses(): - hint = subclass.do_parse(raw_hint, parent_raw_hints=parent_raw_hints) + hint = subclass.do_parse( + raw_hint, parent_raw_hints=parent_raw_hints, discriminator=discriminator + ) if isinstance(hint, ParsingError): return hint if isinstance(hint, Unrecognized): @@ -186,7 +191,10 @@ def __init__(self) -> None: @classmethod @abstractmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": raise NotImplementedError @@ -208,7 +216,10 @@ def get_example(self) -> "Example | Exception": class YamlValueHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": # FIXME: since the spec is yaml, "Any" mostly translates to YamlValue.... but is this always true? if raw_hint == typing.Any: @@ -247,7 +258,10 @@ def __init__(self, raw_hint: Any) -> None: @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if raw_hint not in parent_raw_hints: return Unrecognized(raw_hint=raw_hint) @@ -280,7 +294,10 @@ def __init__(self, pattern: str) -> None: @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if not inspect.isclass(raw_hint) or not any( klass.__name__ == "StringNode" for klass in raw_hint.__mro__ @@ -312,7 +329,10 @@ def get_example(self) -> "Example | Exception": class RootModelHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": from pydantic import RootModel @@ -343,13 +363,25 @@ def __init__(self, inner_hint: Hint, restrictions: List[str]) -> None: @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if raw_hint.__class__ != typing_extensions.Annotated[int, None].__class__: return Unrecognized(raw_hint) + + inner_hint: "Hint | Unrecognized | ParsingError" + discri: Optional[pydantic.Discriminator] = discriminator + for md in raw_hint.__metadata__: + if not isinstance(md, pydantic.Discriminator): + continue + discri = md + break inner_hint = Hint.parse( raw_hint=raw_hint.__args__[0], parent_raw_hints=[*parent_raw_hints, raw_hint], + discriminator=discri, ) if isinstance(inner_hint, (ParsingError, Unrecognized)): return inner_hint.with_context(f"Could not parse inner hint for {raw_hint}") @@ -376,6 +408,12 @@ def do_parse( if "PydanticGeneralMetadata" in type(md).__name__: continue if isinstance(md, Predicate): + if isinstance(md.func, types.LambdaType): + try: + metadata.append(inspect.getsource(md.func).strip()) + except OSError: + eprint("WARNING: could not get lambda source") + continue metadata_str = md.func.__name__ if md.func.__doc__: metadata_str += f": {md.func.__doc__}" @@ -404,7 +442,7 @@ def short_description(self, extra: Sequence["Widget"] = ()) -> "Widget": ) def get_example(self) -> "Example | Exception": - eprint(f"WARNING: Annotated type without a manually provided example") + eprint("WARNING: Annotated type without a manually provided example") return self.inner_hint.get_example() def to_type_widget( @@ -448,12 +486,17 @@ def __init__(self, name: str, inner: Hint) -> None: @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if not isinstance(raw_hint, TypeAliasType): return Unrecognized(raw_hint=raw_hint) inner = Hint.parse( - raw_hint=raw_hint.__value__, parent_raw_hints=[*parent_raw_hints, raw_hint] + raw_hint=raw_hint.__value__, + parent_raw_hints=[*parent_raw_hints, raw_hint], + discriminator=discriminator, ) if isinstance(inner, (Unrecognized, ParsingError)): return inner.with_context( @@ -465,7 +508,7 @@ def short_description(self, extra: Sequence["Widget"] = ()) -> "Widget": # fmt: off return Widget("span", children=[ InlinePre(text=f"{self.name}"), - Widget("span", text=f" (Alias)", style="font-style: italic; opacity: 0.6"), + Widget("span", text=" (Alias)", style="font-style: italic; opacity: 0.6"), *extra ]) # fmt: on @@ -485,7 +528,7 @@ def to_type_widget( self.short_description(extra=extra_summary) ]), self.inner.to_type_widget(path=path, extra_summary=[ - Widget("span", text=f" (Aliased)", style="font-style: italic; opacity: 0.6"), + Widget("span", text=" (Aliased)", style="font-style: italic; opacity: 0.6"), ]) ]) # fmt: on @@ -494,7 +537,10 @@ def to_type_widget( class DatetimeHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if raw_hint != datetime.datetime: return Unrecognized(raw_hint=raw_hint) @@ -520,7 +566,10 @@ def to_type_widget( class DateHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if raw_hint != datetime.date: return Unrecognized(raw_hint=raw_hint) @@ -546,7 +595,10 @@ def to_type_widget( class PathHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": from pathlib import Path, PurePath @@ -574,7 +626,10 @@ def to_type_widget( class EmailHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": from pydantic.networks import EmailStr @@ -602,7 +657,10 @@ def to_type_widget( class UrlHint(Hint): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": from pydantic import AnyUrl @@ -648,7 +706,10 @@ def is_mapping_hint( @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "MappingHint | Unrecognized | ParsingError": if not cls.is_mapping_hint(raw_hint): return Unrecognized(raw_hint) @@ -717,7 +778,10 @@ def __init__(self, values: Sequence["int | float | bool | str | None"]): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "LiteralHint | Unrecognized | ParsingError": some_dummy_literal_hint = Literal["a"] if raw_hint.__class__ != some_dummy_literal_hint.__class__: @@ -784,14 +848,19 @@ def __init__( self, model: Type["BaseModel"], fields: Mapping[str, Tuple[Hint, Example]], + discriminator: Optional[pydantic.Discriminator], ): self.model = model self.fields = fields + self.discriminator = discriminator super().__init__() @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "Hint | Unrecognized | ParsingError": if not inspect.isclass(raw_hint) or not issubclass(raw_hint, BaseModel): return Unrecognized(raw_hint) @@ -804,15 +873,14 @@ def do_parse( ) for field_name, field_info in required_fields_first: - # if field_name == "license": - # import pydevd; pydevd.settrace() - # eprint("asdasd") field_descriptor = f"{raw_hint.__name__}.{field_name}" field_hint = Hint.parse( raw_hint=get_field_annotation(field_info), # raw_hint=typing.get_type_hints(raw_hint, include_extras=True)[field_name], parent_raw_hints=[*parent_raw_hints, raw_hint], + discriminator=None, # discard discriminator as it only applies to the current ModelHint ) + if isinstance(field_hint, (ParsingError, Unrecognized)): return field_hint.with_context( f"Could not parse type of field {field_descriptor}" @@ -836,7 +904,7 @@ def do_parse( fields[field_name] = (field_hint, field_example) - return ModelHint(model=raw_hint, fields=fields) + return ModelHint(model=raw_hint, fields=fields, discriminator=discriminator) def short_description(self, extra: Sequence["Widget"] = ()) -> "Widget": # fmt: off @@ -864,6 +932,8 @@ def to_type_widget( if not isinstance(field_default, PydanticUndefinedType): field_default = Example.try_from_value(field_default) assert not isinstance(field_default, Exception) + if self.discriminator and self.discriminator.discriminator == field_name: + field_default = PydanticUndefined fields.append( FieldData( name=field_name, @@ -893,9 +963,12 @@ def __init__(self, hint_type: PrimitiveType): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "PrimitiveHint | Unrecognized | ParsingError": - if raw_hint == None: + if raw_hint is None: raw_hint = type(None) if not inspect.isclass(raw_hint) or not issubclass( raw_hint, (int, float, bool, str, type(None)) @@ -906,7 +979,7 @@ def do_parse( def short_description(self, extra: Sequence["Widget"] = ()) -> "Widget": # fmt: off return Widget("span", children=[ - InlinePre(text="null" if self.hint_type == type(None) else self.hint_type.__name__), + InlinePre(text="null" if self.hint_type is type(None) else self.hint_type.__name__), *extra ]) # fmt: on @@ -953,7 +1026,10 @@ def __init__( @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "NTuple | Unrecognized | ParsingError": if not is_tuple_hint(raw_hint) or (... in raw_hint.__args__): return Unrecognized(raw_hint) @@ -1021,7 +1097,10 @@ def __init__(self, element_type: Hint, element_example: Example): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "VarLenTuple | Unrecognized | ParsingError": if not is_tuple_hint(raw_hint): return Unrecognized(raw_hint=raw_hint) @@ -1098,7 +1177,10 @@ def is_list_hint( @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "ListHint | Unrecognized | ParsingError": if not cls.is_list_hint(raw_hint): return Unrecognized(raw_hint=raw_hint) @@ -1107,7 +1189,7 @@ def do_parse( parent_raw_hints=[*parent_raw_hints, raw_hint], ) if isinstance(element_hint, (Unrecognized, ParsingError)): - return element_hint.with_context(f"Could not parse List element type") + return element_hint.with_context("Could not parse List element type") element_example = element_hint.get_example() if isinstance(element_example, Exception): return ParsingError( @@ -1155,7 +1237,10 @@ def __init__(self, args: Sequence[Tuple[Hint, Example]]): @classmethod def do_parse( - cls, raw_hint: Any, parent_raw_hints: Sequence[Any] + cls, + raw_hint: Any, + parent_raw_hints: Sequence[Any], + discriminator: Optional[pydantic.Discriminator] = None, ) -> "UnionHint | Unrecognized | ParsingError": some_dummy_union = Union[int, str] if raw_hint.__class__ != some_dummy_union.__class__: @@ -1164,7 +1249,9 @@ def do_parse( union_args: List[Tuple[Hint, Example]] = [] for arg_idx, arg in enumerate(raw_hint.__args__): hint = Hint.parse( - raw_hint=arg, parent_raw_hints=[*parent_raw_hints, raw_hint] + raw_hint=arg, + parent_raw_hints=[*parent_raw_hints, raw_hint], + discriminator=discriminator, ) if isinstance(hint, (Unrecognized, ParsingError)): return hint.with_context(