Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic set strides kernels in dynunet #955

Merged
69 changes: 51 additions & 18 deletions GANDLF/models/dynunet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@
import monai.networks.nets.dynunet as dynunet


def get_kernels_strides(sizes, spacings):
"""
More info: https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19

When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
be divisible by the product of all strides in the corresponding dimension.
In addition, the minimal spatial size should have at least one dimension that has twice the size of
the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.

"""
input_size = sizes
strides, kernels = [], []
while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [
2 if ratio <= 2 and size >= 8 else 1
for (ratio, size) in zip(spacing_ratio, sizes)
]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
for idx, (i, j) in enumerate(zip(sizes, stride)):
assert (
i % j == 0
), f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
sizes = [i / j for i, j in zip(sizes, stride)]
spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel)
strides.append(stride)

strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
return kernels, strides


class dynunet_wrapper(ModelBase):
"""
More info: https://docs.monai.io/en/stable/networks.html#dynunet
Expand All @@ -26,35 +61,33 @@ class dynunet_wrapper(ModelBase):
def __init__(self, parameters: dict):
super(dynunet_wrapper, self).__init__(parameters)

# checking for validation
assert (
"kernel_size" in parameters["model"]
) == True, "\033[0;31m`kernel_size` key missing in parameters"
assert (
"strides" in parameters["model"]
) == True, "\033[0;31m`strides` key missing in parameters"

# defining some defaults
# if not ("upsample_kernel_size" in parameters["model"]):
# parameters["model"]["upsample_kernel_size"] = parameters["model"][
# "strides"
# ][1:]
patch_size = parameters.get("patch_size", None)
spacing = parameters.get(
"spacing_for_internal_computations",
[1.0 for i in range(parameters["model"]["dimension"])],
)
parameters["model"]["kernel_size"] = parameters["model"].get(
"kernel_size", None
)
parameters["model"]["strides"] = parameters["model"].get("strides", None)
if (parameters["model"]["kernel_size"] is None) or (
parameters["model"]["strides"] is None
):
kernel_size, strides = get_kernels_strides(patch_size, spacing)
parameters["model"]["kernel_size"] = kernel_size
parameters["model"]["strides"] = strides

parameters["model"]["filters"] = parameters["model"].get("filters", None)
parameters["model"]["act_name"] = parameters["model"].get(
"act_name", ("leakyrelu", {"inplace": True, "negative_slope": 0.01})
)

parameters["model"]["deep_supervision"] = parameters["model"].get(
"deep_supervision", True
"deep_supervision", False
)

parameters["model"]["deep_supr_num"] = parameters["model"].get(
"deep_supr_num", 1
)

parameters["model"]["res_block"] = parameters["model"].get("res_block", True)

parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False)
parameters["model"]["dropout"] = parameters["model"].get("dropout", None)

Expand Down
12 changes: 0 additions & 12 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ def test_train_segmentation_rad_2d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
Expand Down Expand Up @@ -374,12 +368,6 @@ def test_train_segmentation_rad_3d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
Expand Down
Loading