diff --git a/isic_metadata/metadata.py b/isic_metadata/metadata.py
index 017682c..45c4726 100644
--- a/isic_metadata/metadata.py
+++ b/isic_metadata/metadata.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections import defaultdict
 import functools
 from typing import Any, Callable, Optional, Union
 
@@ -13,6 +14,7 @@
     field_validator,
     model_validator,
 )
+from pydantic_core import PydanticCustomError
 from typing_extensions import Annotated
 
 from isic_metadata.fields import (
@@ -45,6 +47,44 @@ def EnumErrorMessageValidator(enum, field_name: str):  # noqa: N802
     return WrapValidator(functools.partial(validate_enum_message, field_name))
 
 
+class MetadataBatch(BaseModel):
+    """
+    A batch of metadata rows.
+
+    This is useful for performing checks that span across multiple rows.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    items: list[MetadataRow]
+
+    @model_validator(mode="after")
+    def check_patients_lesions(self) -> "MetadataBatch":
+        """
+        Ensure that each lesion in the batch is associated with at most one patient.
+
+        Rows with a falsy patient_id or lesion_id are skipped.
+        """
+        lesion_to_patients: dict[str, set[str]] = defaultdict(set)
+
+        for item in self.items:
+            if item.patient_id and item.lesion_id:
+                lesion_to_patients[item.lesion_id].add(item.patient_id)
+
+        bad_lesions = [
+            lesion for lesion, patients in lesion_to_patients.items() if len(patients) > 1
+        ]
+        if bad_lesions:
+            # NOTE(review): every offending lesion id is placed in the error context under
+            # "examples"; consider capping the list if batches can be large.
+            raise PydanticCustomError(
+                "one_lesion_multiple_patients",
+                "One or more lesions belong to multiple patients",
+                {"examples": bad_lesions},
+            )
+
+        return self
+
+
 class MetadataRow(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
diff --git a/tests/test_batch.py b/tests/test_batch.py
new file mode 100644
index 0000000..384921c
--- /dev/null
+++ b/tests/test_batch.py
@@ -0,0 +1,20 @@
+from pydantic import ValidationError
+import pytest
+
+from isic_metadata.metadata import MetadataBatch, MetadataRow
+
+
+def test_batch():
+    MetadataBatch(items=[MetadataRow(diagnosis="melanoma"), MetadataRow(diagnosis="melanoma")])
+
+
+def test_lesions_belong_to_same_patient():
+    with pytest.raises(ValidationError) as excinfo:
+        MetadataBatch(
+            items=[
+                MetadataRow(lesion_id="foo", patient_id="foopatient"),
+                MetadataRow(lesion_id="foo", patient_id="barpatient"),
+            ]
+        )
+    assert len(excinfo.value.errors()) == 1
+    assert "belong to multiple patients" in excinfo.value.errors()[0]["msg"]