diff --git a/sample-apps/radiology/lib/infers/deepedit.py b/sample-apps/radiology/lib/infers/deepedit.py
index e380d6fa9..cda4af2e5 100644
--- a/sample-apps/radiology/lib/infers/deepedit.py
+++ b/sample-apps/radiology/lib/infers/deepedit.py
@@ -11,6 +11,7 @@

 from typing import Callable, Sequence, Union

+from lib.transforms.transforms import OrientationGuidanceMultipleLabelDeepEditd
 from monai.apps.deepedit.transforms import (
     AddGuidanceFromPointsDeepEditd,
     AddGuidanceSignalDeepEditd,
@@ -87,6 +88,7 @@ def pre_transforms(self, data=None):
         if self.type == InferType.DEEPEDIT:
             t.extend(
                 [
+                    OrientationGuidanceMultipleLabelDeepEditd(ref_image="image", label_names=self.labels),
                     AddGuidanceFromPointsDeepEditd(ref_image="image", guidance="guidance", label_names=self.labels),
                     Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                     ResizeGuidanceMultipleLabelDeepEditd(guidance="guidance", ref_image="image"),
diff --git a/sample-apps/radiology/lib/transforms/transforms.py b/sample-apps/radiology/lib/transforms/transforms.py
index bc0e6719a..e04471338 100644
--- a/sample-apps/radiology/lib/transforms/transforms.py
+++ b/sample-apps/radiology/lib/transforms/transforms.py
@@ -14,6 +14,7 @@

 import numpy as np
 import torch
+from einops import rearrange
 from monai.config import KeysCollection, NdarrayOrTensor
 from monai.transforms import CropForeground, GaussianSmooth, Randomizable, Resize, ScaleIntensity, SpatialCrop
 from monai.transforms.transform import MapTransform, Transform
@@ -505,3 +506,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
             if d.get(cache_key) is None:
                 d[cache_key] = copy.deepcopy(d[key])
         return d
+
+
+class OrientationGuidanceMultipleLabelDeepEditd(Transform):
+    def __init__(self, ref_image: str, label_names=None):
+        """
+        Convert the guidance to the RAS orientation
+        """
+        self.ref_image = ref_image
+        self.label_names = label_names
+
+    def transform_points(self, point, affine):
+        """transform point to the coordinates of the transformed image
+        point: numpy array [bs, N, 3]
+        """
+        bs, N = point.shape[:2]
+        point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1)
+        point = rearrange(point, "b n d -> d (b n)")
+        point = affine @ point
+        point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
+        return point
+
+    def __call__(self, data):
+        d: Dict = dict(data)
+        for key_label in self.label_names.keys():
+            points = d.get(key_label, [])
+            if len(points) < 1:
+                continue
+            reoriented_points = self.transform_points(
+                np.array(points)[None],
+                np.linalg.inv(d[self.ref_image].meta["affine"].numpy()) @ d[self.ref_image].meta["original_affine"],
+            )
+            d[key_label] = reoriented_points[0]
+        return d
diff --git a/sample-apps/radiology/main.py b/sample-apps/radiology/main.py
index f8ed414ec..e34ccbb6a 100644
--- a/sample-apps/radiology/main.py
+++ b/sample-apps/radiology/main.py
@@ -287,7 +287,7 @@ def main():

     parser = argparse.ArgumentParser()
     parser.add_argument("-s", "--studies", default=studies)
-    parser.add_argument("-m", "--model", default="segmentation_spleen")
+    parser.add_argument("-m", "--model", default="deepedit")
     parser.add_argument("-t", "--test", default="infer", choices=("train", "infer"))
     args = parser.parse_args()

@@ -308,7 +308,9 @@

     # Run on all devices
     for device in device_list():
-        res = app.infer(request={"model": args.model, "image": image_id, "device": device})
+        res = app.infer(
+            request={"model": args.model, "image": image_id, "device": device, "spleen": [[6, 6, 6], [9, 9, 9]]}
+        )
         # res = app.infer(
         #     request={"model": "vertebra_pipeline", "image": image_id, "device": device, "slicer": False}
         # )