Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix packaging with weights format priority #585

Merged
merged 7 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ Made with [contrib.rocks](https://contrib.rocks).

### bioimageio.spec Python package

#### bioimageio.spec 0.5.2post1

* fix model packaging with weights format priority

#### bioimageio.spec 0.5.2

* new patch version model 0.5.2
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/spec/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.5.2"
"version": "0.5.2post1"
}
2 changes: 1 addition & 1 deletion bioimageio/spec/_internal/common_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StringNode(collections.UserString, ABC):
_node_class: Type[Node]
_node: Optional[Node] = None

def __init__(self: Self, seq: object) -> None:
def __init__(self, seq: object) -> None:
super().__init__(seq)
type_hints = {
fn: t
Expand Down
17 changes: 13 additions & 4 deletions bioimageio/spec/_internal/packaging_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Sequence, Union

from .io_basics import AbsoluteFilePath, FileName
from .url import HttpUrl
Expand All @@ -16,16 +16,20 @@ class PackagingContext:

bioimageio_yaml_file_name: FileName

file_sources: Dict[FileName, Union[AbsoluteFilePath, HttpUrl]] = field(
default_factory=dict
)
file_sources: Dict[FileName, Union[AbsoluteFilePath, HttpUrl]]
"""File sources to include in the packaged resource"""

weights_priority_order: Optional[Sequence[str]] = None
"""set to select a single weights entry when packaging model resources"""

def replace(
self,
*,
bioimageio_yaml_file_name: Optional[FileName] = None,
file_sources: Optional[Dict[FileName, Union[AbsoluteFilePath, HttpUrl]]] = None,
weights_priority_order: Union[
Optional[Sequence[str]], Literal["unchanged"]
] = "unchanged",
) -> "PackagingContext":
"""return a modiefied copy"""
return PackagingContext(
Expand All @@ -37,6 +41,11 @@ def replace(
file_sources=(
dict(self.file_sources) if file_sources is None else file_sources
),
weights_priority_order=(
self.weights_priority_order
if weights_priority_order == "unchanged"
else weights_priority_order
),
)

def __enter__(self):
Expand Down
42 changes: 3 additions & 39 deletions bioimageio/spec/_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from ._internal.validation_context import validation_context_var
from ._internal.warning_levels import ERROR
from ._io import load_description
from .model.v0_4 import ModelDescr as ModelDescr04
from .model.v0_4 import WeightsFormat
from .model.v0_5 import ModelDescr as ModelDescr05


def get_os_friendly_file_name(name: str) -> str:
Expand Down Expand Up @@ -59,50 +57,16 @@ def get_resource_package_content(
)
content: Dict[FileName, Union[HttpUrl, AbsoluteFilePath]] = {}
with PackagingContext(
bioimageio_yaml_file_name=bioimageio_yaml_file_name, file_sources=content
bioimageio_yaml_file_name=bioimageio_yaml_file_name,
file_sources=content,
weights_priority_order=weights_priority_order,
):
rdf_content: BioimageioYamlContent = rd.model_dump(
mode="json", exclude_unset=True
)

_ = rdf_content.pop("rdf_source", None)

if weights_priority_order is not None and isinstance(
rd, (ModelDescr04, ModelDescr05)
):
# select single weights entry
assert isinstance(rdf_content["weights"], dict), type(rdf_content["weights"])
for wf in weights_priority_order:
w = rdf_content["weights"].get(wf)
if w is not None:
break
else:
raise ValueError(
"None of the weight formats in `weights_priority_order` is present in"
+ " the given model."
)

assert isinstance(w, dict), type(w)
_ = w.pop("parent", None)
rdf_content["weights"] = {wf: w}
parent = rdf_content.pop("id", None)
parent_version = rdf_content.pop("version", None)
if parent is not None:
rdf_content["parent"] = {"id": parent, "version": parent_version}

with validation_context_var.get().replace(
root=rd.root, file_name=bioimageio_yaml_file_name
):
rd_slim = build_description(rdf_content)

assert not isinstance(
rd_slim, InvalidDescr
), rd_slim.validation_summary.format()
# repackage without other weights entries
return get_resource_package_content(
rd_slim, bioimageio_yaml_file_name=bioimageio_yaml_file_name
)

return {**content, bioimageio_yaml_file_name: rdf_content}


Expand Down
41 changes: 39 additions & 2 deletions bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
AllowInfNan,
Discriminator,
Field,
SerializationInfo,
SerializerFunctionWrapHandler,
TypeAdapter,
ValidationInfo,
WrapSerializer,
field_validator,
model_validator,
)
Expand All @@ -46,6 +49,7 @@
from .._internal.io import FileDescr as FileDescr
from .._internal.io import Sha256 as Sha256
from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import ImportantFileSource, LowerCaseIdentifier
Expand Down Expand Up @@ -870,6 +874,37 @@ class LinkedModel(Node):
"""version number (n-th published version, not the semantic version) of linked model"""


def package_weights(
value: Node,
handler: SerializerFunctionWrapHandler,
info: SerializationInfo,
):
ctxt = packaging_context_var.get()
if ctxt is not None and ctxt.weights_priority_order is not None:
for wf in ctxt.weights_priority_order:
w = getattr(value, wf, None)
if w is not None:
break
else:
raise ValueError(
"None of the weight formats in `weights_priority_order`"
+ f" ({ctxt.weights_priority_order}) is present in the given model."
)

# remove links to parent entry (otherwise we cannot remove the parent)
for _, w in value:
if w is not None:
w.parent = None

for field_name in value.model_fields:
if field_name != wf:
setattr(value, field_name, None)

return handler(
value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs
)


class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"):
"""Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

Expand All @@ -888,7 +923,9 @@ class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification")
id: Optional[ModelId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""

authors: NotEmpty[List[Author]]
authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory
List[Author]
]
"""The authors are the creators of the model RDF and the primary points of contact."""

documentation: Annotated[
Expand Down Expand Up @@ -1114,7 +1151,7 @@ def ignore_url_parent(cls, parent: Any):
training_data: Union[LinkedDataset, DatasetDescr, None] = None
"""The dataset used to train this model"""

weights: WeightsDescr
weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
"""The weights for this model.
Weights can be given for different formats, but should otherwise be equivalent.
The available weight formats determine which consumers can use this model."""
Expand Down
4 changes: 3 additions & 1 deletion bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
RootModel,
Tag,
ValidationInfo,
WrapSerializer,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -123,6 +124,7 @@
from .v0_4 import TensorName as _TensorName_v0_4
from .v0_4 import WeightsFormat as WeightsFormat
from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
from .v0_4 import package_weights

# unit names from https://ngff.openmicroscopy.org/latest/#axes-md
SpaceUnit = Literal[
Expand Down Expand Up @@ -2342,7 +2344,7 @@ def _validate_output_axes(
training_data: Union[None, LinkedDataset, DatasetDescr] = None
"""The dataset used to train this model"""

weights: WeightsDescr
weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
"""The weights for this model.
Weights can be given for different formats, but should otherwise be equivalent.
The available weight formats determine which consumers can use this model."""
Expand Down
Loading