Merge pull request #86 from BrainLesion/dynamic_result
Dynamic result
neuronflow authored Jan 24, 2024
2 parents 4ce4a80 + 8dbad1e commit 8b7c47b
Showing 15 changed files with 938 additions and 448 deletions.
9 changes: 4 additions & 5 deletions examples/example_spine_instance.py
@@ -4,7 +4,7 @@
from auxiliary.turbopath import turbopath

from panoptica import MatchedInstancePair, Panoptic_Evaluator
from panoptica.metrics import Metrics
from panoptica.metrics import Metric, MetricMode

directory = turbopath(__file__).parent

@@ -17,16 +17,15 @@

evaluator = Panoptic_Evaluator(
    expected_input=MatchedInstancePair,
    eval_metrics=[Metrics.ASSD, Metrics.IOU],
    decision_metric=Metrics.IOU,
    eval_metrics=[Metric.DSC, Metric.IOU],
    decision_metric=Metric.DSC,
    decision_threshold=0.5,
)


with cProfile.Profile() as pr:
    if __name__ == "__main__":
        result, debug_data = evaluator.evaluate(sample)

        result, debug_data = evaluator.evaluate(sample, verbose=True)
        print(result)

    pr.dump_stats(directory + "/instance_example.log")
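
For orientation, the snippet below is a self-contained miniature of the updated example, using tiny synthetic label arrays in place of the spine NIfTI files. Panoptic_Evaluator, Metric, and the evaluate(sample, verbose=True) call are taken directly from this diff; the MatchedInstancePair(prediction_arr=..., reference_arr=...) keyword signature is an assumption for illustration.

# Hypothetical miniature of examples/example_spine_instance.py (synthetic data).
import numpy as np

from panoptica import MatchedInstancePair, Panoptic_Evaluator
from panoptica.metrics import Metric

# Two matched instances (labels 1 and 2) that overlap between prediction and reference.
reference_arr = np.zeros((20, 20), dtype=np.uint8)
reference_arr[2:8, 2:8] = 1
reference_arr[12:18, 12:18] = 2
prediction_arr = np.zeros_like(reference_arr)
prediction_arr[3:9, 3:9] = 1
prediction_arr[12:17, 12:17] = 2

# Assumed constructor signature; the real example loads NIfTI masks instead.
sample = MatchedInstancePair(prediction_arr=prediction_arr, reference_arr=reference_arr)

evaluator = Panoptic_Evaluator(
    expected_input=MatchedInstancePair,
    eval_metrics=[Metric.DSC, Metric.IOU],
    decision_metric=Metric.DSC,
    decision_threshold=0.5,
)

if __name__ == "__main__":
    result, debug_data = evaluator.evaluate(sample, verbose=True)
    print(result)
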
2 changes: 1 addition & 1 deletion examples/example_spine_semantic.py
@@ -9,7 +9,7 @@
    Panoptic_Evaluator,
    SemanticPair,
)
from panoptica.metrics import Metrics
from panoptica.metrics import Metric

directory = turbopath(__file__).parent

6 changes: 3 additions & 3 deletions panoptica/_functionals.py
@@ -2,7 +2,7 @@

import numpy as np

from panoptica.metrics import _compute_instance_iou, _MatchingMetric
from panoptica.metrics import _compute_instance_iou, Metric
from panoptica.utils.constants import CCABackend
from panoptica.utils.numpy_utils import _get_bbox_nd

@@ -41,7 +41,7 @@ def _calc_matching_metric_of_overlapping_labels(
    prediction_arr: np.ndarray,
    reference_arr: np.ndarray,
    ref_labels: tuple[int, ...],
    matching_metric: _MatchingMetric,
    matching_metric: Metric,
) -> list[tuple[float, tuple[int, int]]]:
    """Calculates the MatchingMetric for all overlapping labels (fast!)
@@ -62,7 +62,7 @@ def _calc_matching_metric_of_overlapping_labels(
        )
    ]
    with Pool() as pool:
        mm_values = pool.starmap(matching_metric._metric_function, instance_pairs)
        mm_values = pool.starmap(matching_metric.value._metric_function, instance_pairs)

    mm_pairs = [
        (i, (instance_pairs[idx][2], instance_pairs[idx][3]))
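
The new .value hop, together with the plain metric(ref_arr, pred_arr) call that appears later in instance_evaluator.py, suggests that Metric is now an Enum whose members wrap the former _Metric objects. The toy below is a guess at that pattern, not panoptica's actual implementation: .value exposes the wrapped object (handy when only the raw function should be shipped to Pool.starmap), while __call__ forwards to it for direct use.

from enum import Enum


class _Metric:
    """Toy stand-in: bundles a name with the function that computes the metric."""

    def __init__(self, name, metric_function):
        self.name = name
        self._metric_function = metric_function


def _iou(reference: set, prediction: set) -> float:
    union = len(reference | prediction)
    return len(reference & prediction) / union if union else 0.0


class Metric(Enum):
    # Each member's value is a _Metric instance, so member.value._metric_function
    # is the plain callable that gets handed to Pool.starmap.
    IOU = _Metric("IOU", _iou)

    def __call__(self, reference, prediction):
        # Direct calls on the member forward to the wrapped function.
        return self.value._metric_function(reference, prediction)


matching_metric = Metric.IOU
print(matching_metric.value._metric_function({1, 2, 3}, {2, 3, 4}))  # 0.5
print(matching_metric({1, 2, 3}, {2, 3, 4}))  # 0.5
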
37 changes: 16 additions & 21 deletions panoptica/instance_evaluator.py
@@ -1,23 +1,16 @@
import concurrent.futures
import gc
from multiprocessing import Pool

import numpy as np

from panoptica.metrics import (
    _MatchingMetric,
)
from panoptica.panoptic_result import PanopticaResult
from panoptica.timing import measure_time
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair
from panoptica.metrics import Metrics
from panoptica.metrics import Metric


def evaluate_matched_instance(
    matched_instance_pair: MatchedInstancePair,
    eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD],
    decision_metric: _MatchingMetric | None = Metrics.IOU,
    eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
    decision_metric: Metric | None = Metric.IOU,
    decision_threshold: float | None = None,
    edge_case_handler: EdgeCaseHandler | None = None,
    **kwargs,
@@ -46,9 +39,7 @@ def evaluate_matched_instance(
    assert decision_threshold is not None, "decision metric set but no threshold"
    # Initialize variables for True Positives (tp)
    tp = len(matched_instance_pair.matched_instances)
    score_dict: dict[str | _MatchingMetric, list[float]] = {
        m.name: [] for m in eval_metrics
    }
    score_dict: dict[Metric, list[float]] = {m: [] for m in eval_metrics}

    reference_arr, prediction_arr = (
        matched_instance_pair._reference_arr,
@@ -61,22 +52,26 @@ def evaluate_matched_instance(
        for ref_idx in ref_matched_labels
    ]
    with Pool() as pool:
        metric_dicts = pool.starmap(_evaluate_instance, instance_pairs)
        metric_dicts: list[dict[Metric, float]] = pool.starmap(
            _evaluate_instance, instance_pairs
        )

    for metric_dict in metric_dicts:
        if decision_metric is None or (
            decision_threshold is not None
            and decision_metric.score_beats_threshold(
                metric_dict[decision_metric.name], decision_threshold
                metric_dict[decision_metric], decision_threshold
            )
        ):
            for k, v in metric_dict.items():
                score_dict[k].append(v)

    # Create and return the PanopticaResult object with computed metrics
    return PanopticaResult(
        num_ref_instances=matched_instance_pair.n_reference_instance,
        reference_arr=matched_instance_pair.reference_arr,
        prediction_arr=matched_instance_pair.prediction_arr,
        num_pred_instances=matched_instance_pair.n_prediction_instance,
        num_ref_instances=matched_instance_pair.n_reference_instance,
        tp=tp,
        list_metrics=score_dict,
        edge_case_handler=edge_case_handler,
@@ -87,8 +82,8 @@ def _evaluate_instance(
    reference_arr: np.ndarray,
    prediction_arr: np.ndarray,
    ref_idx: int,
    eval_metrics: list[_MatchingMetric],
) -> dict[str, float]:
    eval_metrics: list[Metric],
) -> dict[Metric, float]:
    """
    Evaluate a single instance.
@@ -103,12 +98,12 @@ def _evaluate_instance(
    """
    ref_arr = reference_arr == ref_idx
    pred_arr = prediction_arr == ref_idx
    result: dict[str, float] = {}
    result: dict[Metric, float] = {}
    if ref_arr.sum() == 0 or pred_arr.sum() == 0:
        return result
    else:
        for metric in eval_metrics:
            value = metric._metric_function(ref_arr, pred_arr)
            result[metric.name] = value
            metric_value = metric(ref_arr, pred_arr)
            result[metric] = metric_value

    return result
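
After this change the per-instance dictionaries are keyed by the Metric member itself rather than by metric.name, and the member is called directly. The toy below (not panoptica's code) illustrates how those dictionaries are then filtered by the decision metric and folded into score_dict; a plain >= threshold check stands in for score_beats_threshold, which in panoptica also accounts for whether larger or smaller scores are better.

from enum import Enum


class Metric(Enum):  # minimal stand-in for panoptica.metrics.Metric
    DSC = "dsc"
    IOU = "iou"


# Pretend output of Pool.starmap(_evaluate_instance, instance_pairs).
metric_dicts: list[dict[Metric, float]] = [
    {Metric.DSC: 0.91, Metric.IOU: 0.84},
    {Metric.DSC: 0.42, Metric.IOU: 0.27},  # below threshold, should be dropped
]

decision_metric, decision_threshold = Metric.DSC, 0.5
score_dict: dict[Metric, list[float]] = {m: [] for m in Metric}

for metric_dict in metric_dicts:
    if metric_dict[decision_metric] >= decision_threshold:  # stand-in check
        for k, v in metric_dict.items():
            score_dict[k].append(v)

print(score_dict)  # {<Metric.DSC: 'dsc'>: [0.91], <Metric.IOU: 'iou'>: [0.84]}
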
6 changes: 3 additions & 3 deletions panoptica/instance_matcher.py
@@ -6,7 +6,7 @@
    _calc_matching_metric_of_overlapping_labels,
    _map_labels,
)
from panoptica.metrics import Metrics, _MatchingMetric
from panoptica.metrics import Metric, _Metric
from panoptica.utils.processing_pair import (
    InstanceLabelMap,
    MatchedInstancePair,
@@ -153,7 +153,7 @@ class NaiveThresholdMatching(InstanceMatchingAlgorithm):

    def __init__(
        self,
        matching_metric: _MatchingMetric = Metrics.IOU,
        matching_metric: Metric = Metric.IOU,
        matching_threshold: float = 0.5,
        allow_many_to_one: bool = False,
    ) -> None:
@@ -228,7 +228,7 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm):

    def __init__(
        self,
        matching_metric: _MatchingMetric = Metrics.IOU,
        matching_metric: Metric = Metric.IOU,
        matching_threshold: float = 0.5,
    ) -> None:
        """
15 changes: 8 additions & 7 deletions panoptica/metrics/__init__.py
@@ -6,11 +6,12 @@
    _compute_dice_coefficient,
    _compute_instance_volumetric_dice,
)
from panoptica.metrics.iou import _compute_instance_iou, _compute_iou
from panoptica.metrics.metrics import (
    Metrics,
    ListMetric,
    EvalMetric,
    MetricDict,
    _MatchingMetric,
from panoptica.metrics.iou import (
    _compute_instance_iou,
    _compute_iou,
)
from panoptica.metrics.cldice import (
    _compute_centerline_dice,
    _compute_centerline_dice_coefficient,
)
from panoptica.metrics.metrics import Metric, _Metric, MetricMode
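
A quick smoke test of the reorganized namespace: every name below is exported by the new __init__.py shown above; beyond importability, only enum-style member access is assumed.

from panoptica.metrics import (
    Metric,
    MetricMode,
    _compute_centerline_dice_coefficient,
    _compute_iou,
)

# Enum-style access replaces the old Metrics class.
print(Metric.IOU)
print(list(MetricMode))  # aggregation modes for the dynamic result (assumed to be an Enum)
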
57 changes: 57 additions & 0 deletions panoptica/metrics/cldice.py
@@ -0,0 +1,57 @@
from skimage.morphology import skeletonize, skeletonize_3d
import numpy as np


def cl_score(volume: np.ndarray, skeleton: np.ndarray):
    """Computes the skeleton volume overlap
    Args:
        volume (np.ndarray): volume
        skeleton (np.ndarray): skeleton
    Returns:
        _type_: skeleton overlap
    """
    return np.sum(volume * skeleton) / np.sum(skeleton)


def _compute_centerline_dice(
    ref_labels: np.ndarray,
    pred_labels: np.ndarray,
    ref_instance_idx: int,
    pred_instance_idx: int,
) -> float:
    """Compute the centerline Dice (clDice) coefficient between a specific pair of instances.
    Args:
        ref_labels (np.ndarray): Reference instance labels.
        pred_labels (np.ndarray): Prediction instance labels.
        ref_instance_idx (int): Index of the reference instance.
        pred_instance_idx (int): Index of the prediction instance.
    Returns:
        float: clDice coefficient
    """
    ref_instance_mask = ref_labels == ref_instance_idx
    pred_instance_mask = pred_labels == pred_instance_idx
    return _compute_centerline_dice_coefficient(
        reference=ref_instance_mask,
        prediction=pred_instance_mask,
    )


def _compute_centerline_dice_coefficient(
    reference: np.ndarray,
    prediction: np.ndarray,
    *args,
) -> float:
    ndim = reference.ndim
    assert 2 <= ndim <= 3, "clDice only implemented for 2D or 3D"
    if ndim == 2:
        tprec = cl_score(prediction, skeletonize(reference))
        tsens = cl_score(reference, skeletonize(prediction))
    elif ndim == 3:
        tprec = cl_score(prediction, skeletonize_3d(reference))
        tsens = cl_score(reference, skeletonize_3d(prediction))

    return 2 * tprec * tsens / (tprec + tsens)
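
The new helper can be exercised directly on a small synthetic 2D pair; only the panoptica.metrics.cldice module path from this diff and scikit-image are required. Two thick bars that almost coincide should yield a clDice close to 1.0.

import numpy as np

from panoptica.metrics.cldice import _compute_centerline_dice_coefficient

reference = np.zeros((32, 32), dtype=bool)
reference[14:18, 4:28] = True   # 4-pixel-thick horizontal bar
prediction = np.zeros_like(reference)
prediction[14:18, 6:30] = True  # same bar, shifted two pixels to the right

score = _compute_centerline_dice_coefficient(reference=reference, prediction=prediction)
print(f"clDice: {score:.3f}")   # close to 1.0
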
