Skip to content

Commit

Permalink
Merge pull request #43 from ImageMarkup/fix-hierarchical-diagnosis-bug
Browse files Browse the repository at this point in the history
Make hierarchical diagnosis support passing in multiple values
  • Loading branch information
danlamanna authored Oct 24, 2024
2 parents 8493fc5 + 3d51cbb commit d742154
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 42 deletions.
26 changes: 24 additions & 2 deletions isic_metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,32 @@ def __init__(self, **kwargs) -> None:

super().__init__(**kwargs)

# See https://github.com/samuelcolvin/pydantic/issues/2285 for more detail
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> dict[str, Any]:
def handle_hierarchical_diagnosis_modes_and_unstructured_fields(
cls, values: dict[str, Any]
) -> dict[str, Any]:
"""
Handle the case where hierarchical diagnosis values are passed in as multiple fields.
Practically, ingesting data should never pass in multiple values but instead use the
colon-separated `diagnosis` field. This method is provided for the scenario where
data needs to be retrieved from the database (where it's stored multi-valued) and
revalidated. This method also handles putting any unrecognized fields into an unstructured
field. Unfortunately, pydantic doesn't yet support ordering different model validators so
these both need to be combined into one method.
"""
using_diagnoses_multi_values = any(f"diagnosis_{i}" in values for i in range(1, 6))
using_diagnosis_single_value = bool(values.get("diagnosis"))

if using_diagnoses_multi_values and using_diagnosis_single_value:
[values.pop(f"diagnosis_{i}", "") for i in range(1, 6)]
elif using_diagnoses_multi_values:
values["diagnosis"] = ":".join(values.pop(f"diagnosis_{i}", "") for i in range(1, 6))
values["diagnosis"] = values["diagnosis"].rstrip(":")

# handle unstructured fields
# See https://github.com/samuelcolvin/pydantic/issues/2285 for more detail
structured_field_names = {field for field in cls.model_fields if field != "unstructured"}

unstructured: dict[str, Any] = {}
Expand Down
40 changes: 0 additions & 40 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,43 +124,3 @@ def test_clin_size_long_diam_mm_invalid():
MetadataRow.model_validate({"clin_size_long_diam_mm": "foo"})
assert len(excinfo.value.errors()) == 1
assert "Unable to parse value as a number" in convert_errors(excinfo.value)[0]["msg"]


@pytest.mark.parametrize(
    ("raw", "parsed"),
    [
        ("Benign", ["Benign"]),
        ("Benign - Other", ["Benign", "Benign - Other"]),
        (
            "Blue nevus",
            ["Benign", "Benign melanocytic proliferations", "Nevus", "Blue nevus"],
        ),
        (
            "Squamous cell carcinoma, NOS",
            ["Malignant", "Malignant epidermal proliferations", "Squamous cell carcinoma, NOS"],
        ),
        (
            "Blue nevus, Sclerosing",
            [
                "Benign",
                "Benign melanocytic proliferations",
                "Nevus",
                "Blue nevus",
                "Blue nevus, Sclerosing",
            ],
        ),
    ],
)
def test_diagnosis(raw, parsed):
    """Validate a single `diagnosis` value and check each expected `diagnosis_N` level."""
    row = MetadataRow.model_validate({"diagnosis": raw})

    # `parsed` lists the expected hierarchy from the top level (diagnosis_1) downwards.
    for level, expected in enumerate(parsed, start=1):
        assert getattr(row, f"diagnosis_{level}") == expected


def test_top_level_diagnosis_is_never_exported():
    """The `diagnosis` input key must not appear in the dump; `diagnosis_1` holds the value."""
    row = MetadataRow.model_validate({"diagnosis": "Benign"})
    exported = row.model_dump()
    assert "diagnosis" not in exported
    assert row.diagnosis_1 == "Benign"


def test_diagnosis_enum_has_unique_terminal_values():
    """The last colon-separated segment of every `DiagnosisEnum` value must be unique."""
    leaves = [member.value.rsplit(":", 1)[-1] for member in DiagnosisEnum]
    distinct_leaves = set(leaves)
    assert len(leaves) == len(distinct_leaves)
87 changes: 87 additions & 0 deletions tests/test_hierarchical_diagnosis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from pydantic import ValidationError
import pytest

from isic_metadata.diagnosis_hierarchical import DiagnosisEnum
from isic_metadata.metadata import MetadataRow


@pytest.mark.parametrize(
    ("raw", "parsed"),
    [
        ("Benign", ["Benign"]),
        ("Benign - Other", ["Benign", "Benign - Other"]),
        (
            "Blue nevus",
            ["Benign", "Benign melanocytic proliferations", "Nevus", "Blue nevus"],
        ),
        (
            "Squamous cell carcinoma, NOS",
            ["Malignant", "Malignant epidermal proliferations", "Squamous cell carcinoma, NOS"],
        ),
        (
            "Blue nevus, Sclerosing",
            [
                "Benign",
                "Benign melanocytic proliferations",
                "Nevus",
                "Blue nevus",
                "Blue nevus, Sclerosing",
            ],
        ),
    ],
)
def test_diagnosis(raw, parsed):
    """Validate a single `diagnosis` value and check each expected `diagnosis_N` level."""
    row = MetadataRow.model_validate({"diagnosis": raw})

    # `parsed` lists the expected hierarchy from the top level (diagnosis_1) downwards.
    for level, expected in enumerate(parsed, start=1):
        assert getattr(row, f"diagnosis_{level}") == expected


def test_top_level_diagnosis_is_never_exported():
    """The `diagnosis` input key must not appear in the dump; `diagnosis_1` holds the value."""
    row = MetadataRow.model_validate({"diagnosis": "Benign"})
    exported = row.model_dump()
    assert "diagnosis" not in exported
    assert row.diagnosis_1 == "Benign"


def test_diagnosis_enum_has_unique_terminal_values():
    """The last colon-separated segment of every `DiagnosisEnum` value must be unique."""
    leaves = [member.value.rsplit(":", 1)[-1] for member in DiagnosisEnum]
    distinct_leaves = set(leaves)
    assert len(leaves) == len(distinct_leaves)


def test_single_value_diagnosis_is_favored():
    # When both the single `diagnosis` value and diagnosis_1..5 are supplied, the single
    # value should win. This mirrors data coming from the database that may hold an
    # existing 1..5 diagnosis alongside a newly updated single diagnosis.
    payload = {
        "diagnosis": "Melanoma Invasive",
        "nevus_type": "blue",
        # these should be ignored
        "diagnosis_1": "Benign",
        "diagnosis_2": "Benign melanocytic proliferations",
        "diagnosis_3": "Nevus",
    }

    with pytest.raises(ValidationError) as excinfo:
        MetadataRow.model_validate(payload)

    # The reported conflict shows the melanoma diagnosis was used, not the nevus levels.
    first_error = excinfo.value.errors()[0]
    assert "Setting nevus_type is incompatible with diagnosis" in first_error["msg"]


def test_diagnosis_multiple_levels_is_coerced():
    # Supplying diagnosis_1..5 directly should be coerced into the single diagnosis field
    # so that cross field input validation still runs.
    row = MetadataRow.model_validate({"diagnosis_1": "Benign"})

    assert row.diagnosis_1 == "Benign"
    # Only the first level was supplied, so every deeper level must stay unset.
    for level in range(2, 6):
        assert getattr(row, f"diagnosis_{level}") is None


def test_diagnosis_validation_is_idempotent():
    # Re-validating the dumped output of a MetadataRow must yield an equal MetadataRow.
    first_pass = MetadataRow.model_validate({"diagnosis": "Melanoma Invasive"})
    assert first_pass.diagnosis_1 == "Malignant"
    assert first_pass.diagnosis_2 == "Malignant melanocytic proliferations (Melanoma)"
    assert first_pass.diagnosis_3 == "Melanoma Invasive"

    # The dump carries diagnosis_1..3 (multi-valued form), which is fed back into validation.
    dumped = first_pass.model_dump(
        exclude_unset=True, exclude_none=True, exclude={"unstructured"}
    )
    second_pass = MetadataRow.model_validate(dumped)
    assert first_pass == second_pass

0 comments on commit d742154

Please sign in to comment.