diff --git a/rising/transforms/functional/intensity.py b/rising/transforms/functional/intensity.py
index 9ffbc24d..30081bc6 100644
--- a/rising/transforms/functional/intensity.py
+++ b/rising/transforms/functional/intensity.py
@@ -171,12 +171,7 @@ def add_noise(data: torch.Tensor, noise_type: str, out: torch.Tensor = None,
         noise_type = noise_type + '_'
     noise_tensor = torch.empty_like(data, requires_grad=False)
     getattr(noise_tensor, noise_type)(**kwargs)
-
-    if out is None:
-        return data + noise_tensor
-    else:
-        out = data + noise_tensor
-        return out
+    return torch.add(data, noise_tensor, out=out)
 
 
 def gamma_correction(data: torch.Tensor, gamma: float) -> torch.Tensor:
@@ -219,7 +214,7 @@ def add_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
     torch.Tensor
         augmented data
     """
-    return torch.add(data, value, out)
+    return torch.add(data, value, out=out)
 
 
 def scale_by_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
@@ -242,4 +237,4 @@ def scale_by_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
     torch.Tensor
         augmented data
     """
-    return torch.mul(data, value, out)
+    return torch.mul(data, value, out=out)
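
The refactor above relies on the out-parameter semantics shared by torch.add and torch.mul: with out=None a fresh result tensor is allocated and returned, while a preallocated out receives the result in place and is itself returned, so both branches of the removed if/else collapse into a single call. A minimal standalone sketch of that equivalence (illustrative, not part of the patch):

    import torch

    data = torch.ones(2, 3)
    noise = torch.full((2, 3), 0.5)

    # out=None: a new tensor is allocated and returned
    fresh = torch.add(data, noise, out=None)
    assert torch.equal(fresh, data + noise)

    # preallocated out: the result is written in place and the buffer returned
    buf = torch.empty_like(data)
    returned = torch.add(data, noise, out=buf)
    assert returned is buf and torch.equal(buf, data + noise)
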
diff --git a/rising/transforms/intensity.py b/rising/transforms/intensity.py
index 8c59eda1..534cba59 100644
--- a/rising/transforms/intensity.py
+++ b/rising/transforms/intensity.py
@@ -201,7 +201,8 @@ def __init__(self, gamma: Union[float, Sequence] = (0.5, 2),
         kwargs:
             keyword arguments passed to superclass
         """
-        super().__init__(augment_fn=gamma_correction, keys=keys, grad=grad, **kwargs)
+        super().__init__(augment_fn=gamma_correction, keys=keys, grad=grad)
+        self.kwargs = kwargs
         self.gamma = gamma
         if not check_scalar(self.gamma):
             if not len(self.gamma) == 2:
@@ -222,26 +223,28 @@ def forward(self, **data) -> dict:
             dict
                 dict with augmented data
         """
-        for _key in self.keys:
-            if check_scalar(self.gamma):
-                _gamma = self.gamma
-            elif self.gamma[1] < 1:
-                _gamma = random.uniform(self.gamma[0], self.gamma[1])
+        if check_scalar(self.gamma):
+            _gamma = self.gamma
+        elif self.gamma[1] < 1:
+            _gamma = random.uniform(self.gamma[0], self.gamma[1])
+        else:
+            if random.random() < 0.5:
+                _gamma = random.uniform(self.gamma[0], 1)
             else:
-                if random.random() < 0.5:
-                    _gamma = _gamma = random.uniform(self.gamma[0], 1)
-                else:
-                    _gamma = _gamma = random.uniform(1, self.gamma[1])
+                _gamma = random.uniform(1, self.gamma[1])
+
+        for _key in self.keys:
             data[_key] = self.augment_fn(data[_key], _gamma, **self.kwargs)
         return data
 
 
-class RandomAddValue(PerChannelTransform):
-    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
-                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
+class RandomValuePerChannelTransform(PerChannelTransform):
+    def __init__(self, augment_fn: callable, random_mode: str, random_kwargs: dict = None,
+                 per_channel: bool = False, keys: Sequence = ('data',),
+                 grad: bool = False, **kwargs):
         """
-        Increase values additively
+        Apply augmentations which take a randomly sampled value as input
+        via the keyword argument :param:`value`
 
         Parameters
         ----------
@@ -259,27 +262,39 @@ def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool =
         kwargs:
             keyword arguments passed to augment_fn
         """
-        super().__init__(augment_fn=add_value, per_channel=per_channel,
+        super().__init__(augment_fn=augment_fn, per_channel=per_channel,
                          keys=keys, grad=grad, **kwargs)
         self.random_mode = random_mode
         self.random_kwargs = {} if random_kwargs is None else random_kwargs
 
     def forward(self, **data) -> dict:
         """
-        Apply transformation
+        Perform augmentation
 
         Parameters
         ----------
         data: dict
-            dict with tensors
+            dict with data
 
         Returns
         -------
         dict
-            dict with augmented data
-        """
-        self.kwargs["value"] = self.random_fn(**self.random_kwargs)
-        return super().forward(**data)
+            augmented data
+        """
+        if self.per_channel:
+            random_seed = random.random()
+            for _key in self.keys:
+                random.seed(random_seed)
+                out = torch.empty_like(data[_key])
+                for _i in range(data[_key].shape[1]):
+                    rand_value = self.random_fn(**self.random_kwargs)
+                    out[:, _i] = self.augment_fn(data[_key][:, _i], value=rand_value,
+                                                 out=out[:, _i], **self.kwargs)
+                data[_key] = out
+            return data
+        else:
+            self.kwargs["value"] = self.random_fn(**self.random_kwargs)
+            return super().forward(**data)
 
     @property
     def random_mode(self) -> str:
@@ -308,11 +323,11 @@ def random_mode(self, mode) -> None:
         self.random_fn = getattr(random, mode)
 
 
-class RandomScaleValue(PerChannelTransform):
-    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
+class RandomAddValue(RandomValuePerChannelTransform):
+    def __init__(self, random_mode: str, random_kwargs: dict = None, per_channel: bool = False,
                  keys: Sequence = ('data',), grad: bool = False, **kwargs):
         """
-        Scale values
+        Increase values additively
 
         Parameters
         ----------
@@ -330,37 +345,33 @@ def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool =
         kwargs:
             keyword arguments passed to augment_fn
         """
-        super().__init__(augment_fn=scale_by_value, per_channel=per_channel,
+        super().__init__(augment_fn=add_value, random_mode=random_mode,
+                         random_kwargs=random_kwargs, per_channel=per_channel,
                          keys=keys, grad=grad, **kwargs)
-        self.random_mode = random_mode
-        self.random_kwargs = {} if random_kwargs is None else random_kwargs
-
-    def forward(self, **data) -> dict:
-        self.kwargs["value"] = self.random_fn(**self.random_kwargs)
-        return super().forward(**data)
-
-    @property
-    def random_mode(self) -> str:
-        """
-        Get random mode
-
-        Returns
-        -------
-        str
-            random mode
-        """
-        return self._random_mode
-
-    @random_mode.setter
-    def random_mode(self, mode) -> None:
+
+
+class RandomScaleValue(RandomValuePerChannelTransform):
+    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
+                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
         """
-        Set random mode
+        Scale values
 
         Parameters
         ----------
-        mode: str
+        random_mode: str
             specifies distribution which should be used to sample additive value
             (supports all random generators from python random package)
+        random_kwargs: dict
+            additional arguments for random function
+        per_channel: bool
+            enable transformation per channel
+        keys: Sequence
+            keys which should be augmented
+        grad: bool
+            enable gradient computation inside transformation
+        kwargs:
+            keyword arguments passed to augment_fn
         """
-        self._random_mode = mode
-        self.random_fn = getattr(random, mode)
\ No newline at end of file
+        super().__init__(augment_fn=scale_by_value, random_mode=random_mode,
+                         random_kwargs=random_kwargs, per_channel=per_channel,
+                         keys=keys, grad=grad, **kwargs)
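
With the hierarchy above, RandomAddValue and RandomScaleValue become thin wrappers that only pick the augment_fn; the sampling and per-channel logic live in RandomValuePerChannelTransform. A usage sketch, assuming the classes are importable from rising.transforms.intensity as in the tests further down (the tensor shapes are illustrative):

    import torch
    from rising.transforms.intensity import RandomAddValue, RandomScaleValue

    batch = {"data": torch.rand(4, 3, 32, 32)}  # [batch, channel, height, width]

    # one additive value from random.random() shared by all channels
    add_trafo = RandomAddValue(random_mode="random")
    shifted = add_trafo(**batch)

    # per_channel=True: one factor per channel, drawn via random.uniform(a, b)
    scale_trafo = RandomScaleValue(random_mode="uniform",
                                   random_kwargs={"a": 0.9, "b": 1.1},
                                   per_channel=True)
    scaled = scale_trafo(**batch)
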
diff --git a/rising/transforms/kernel.py b/rising/transforms/kernel.py
index 2756ec96..1b912611 100644
--- a/rising/transforms/kernel.py
+++ b/rising/transforms/kernel.py
@@ -1,6 +1,6 @@
 import math
 import torch
-from typing import Sequence, Union
+from typing import Sequence, Union, Callable
 
 from .abstract import AbstractTransform
 from rising.utils import check_scalar
@@ -61,9 +61,10 @@ def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], dim: int
         kernel = self.create_kernel()
         self.register_buffer('weight', kernel)
         self.groups = in_channels
-        self.set_conv(dim)
+        self.conv = self.get_conv(dim)
 
-    def set_conv(self, dim) -> None:
+    @staticmethod
+    def get_conv(dim) -> Callable:
         """
         Select convolution with regard to dimension
 
@@ -73,13 +74,13 @@ def set_conv(self, dim) -> None:
             spatial dimension of data
         """
         if dim == 1:
-            self.conv = torch.nn.functional.conv1d
+            return torch.nn.functional.conv1d
         elif dim == 2:
-            self.conv = torch.nn.functional.conv2d
+            return torch.nn.functional.conv2d
         elif dim == 3:
-            self.conv = torch.nn.functional.conv3d
+            return torch.nn.functional.conv3d
         else:
-            raise RuntimeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
+            raise TypeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
 
     def create_kernel(self) -> torch.Tensor:
         """
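
Because get_conv is now a staticmethod returning the matching torch.nn.functional convolution, it can be exercised without instantiating a transform, which is exactly what the new kernel test below does. A small sketch (the box-blur kernel is illustrative, not library code):

    import torch
    from rising.transforms.kernel import KernelTransform

    conv = KernelTransform.get_conv(2)
    assert conv is torch.nn.functional.conv2d

    inp = torch.rand(1, 1, 8, 8)                 # [batch, channel, height, width]
    weight = torch.full((1, 1, 3, 3), 1. / 9.)   # 3x3 box blur
    blurred = conv(inp, weight, stride=1, padding=1, groups=1)
    assert blurred.shape == inp.shape
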
diff --git a/tests/test_transforms/test_abstract_transform.py b/tests/test_transforms/test_abstract_transform.py
index 7cf02061..137475b9 100644
--- a/tests/test_transforms/test_abstract_transform.py
+++ b/tests/test_transforms/test_abstract_transform.py
@@ -78,6 +78,30 @@ def augment_fn(inp, *args, **kwargs):
                 call(torch.tensor([2])), ]
         mock.assert_has_calls(calls)
 
+    def test_per_channel_transform_per_channel_true(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = PerChannelTransform(augment_fn, per_channel=True, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        calls = [call(torch.tensor([0])), call(torch.tensor([1])),
+                 call(torch.tensor([2])), ]
+        mock.assert_has_calls(calls)
+
+    def test_per_channel_transform_per_channel_false(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = PerChannelTransform(augment_fn, per_channel=False, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        mock.assert_called_once()
+
     def test_random_dims_transform(self):
         torch.manual_seed(0)
         self.batch_dict["data"] = torch.rand(1, 1, 32, 16)
diff --git a/tests/test_transforms/test_functional/test_intensity.py b/tests/test_transforms/test_functional/test_intensity.py
index e8ce88f3..7ef5a48b 100644
--- a/tests/test_transforms/test_functional/test_intensity.py
+++ b/tests/test_transforms/test_functional/test_intensity.py
@@ -80,6 +80,18 @@ def test_add_noise(self):
         diff = (outp - self.batch_2d).abs().mean()
         self.assertTrue(diff > 50)
 
+    def test_gamma_correction(self):
+        outp = gamma_correction(self.batch_2d, 2)
+        self.assertTrue((self.batch_2d.pow(2) == outp).all())
+
+    def test_add_value(self):
+        outp = add_value(self.batch_2d, 2)
+        self.assertTrue((torch.add(self.batch_2d, 2) == outp).all())
+
+    def test_scale_by_value(self):
+        outp = scale_by_value(self.batch_2d, 2)
+        self.assertTrue((torch.mul(self.batch_2d, 2) == outp).all())
+
 
 if __name__ == '__main__':
     unittest.main()
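
The per-channel branch of RandomValuePerChannelTransform.forward above re-seeds Python's random module once per key, so every key draws the identical sequence of per-channel values. A standalone sketch of that seeding trick (illustrative only):

    import random

    seed = random.random()  # sampled once per forward pass

    draws = []
    for _key in ("data", "label"):       # every key restarts the same sequence
        random.seed(seed)
        draws.append([random.random() for _channel in range(3)])

    assert draws[0] == draws[1]          # both keys saw identical values
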
diff --git a/tests/test_transforms/test_intensity_transforms.py b/tests/test_transforms/test_intensity_transforms.py
index cdd939dd..66e5026d 100644
--- a/tests/test_transforms/test_intensity_transforms.py
+++ b/tests/test_transforms/test_intensity_transforms.py
@@ -2,6 +2,7 @@
 import torch
 import random
 from math import isclose
+from unittest.mock import Mock, call
 
 from tests.test_transforms import chech_data_preservation
 from rising.transforms.intensity import *
@@ -91,6 +92,74 @@ def check_noise_distance(self, trafo, min_diff=50):
         comp_diff = (outp["data"] - self.batch_dict["data"]).mean().item()
         self.assertTrue(comp_diff > min_diff)
 
+    def test_per_channel_transform_per_channel_true(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = RandomValuePerChannelTransform(
+            augment_fn, random_mode="random", per_channel=True, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        calls = [call(torch.tensor([0])), call(torch.tensor([1])),
+                 call(torch.tensor([2])), ]
+        mock.assert_has_calls(calls)
+
+    def test_random_add_value(self):
+        trafo = RandomAddValue("random")
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        random.seed(0)
+        rand_val = random.random()
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"] + rand_val
+        self.assertTrue((outp["data"] == expected_out).all())
+        self.assertEqual(trafo.random_mode, "random")
+
+    def test_random_scale_value(self):
+        trafo = RandomScaleValue("random")
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        random.seed(0)
+        rand_val = random.random()
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"] * rand_val
+        self.assertTrue((outp["data"] == expected_out).all())
+        self.assertEqual(trafo.random_mode, "random")
+
+    def test_gamma_transform_scalar(self):
+        trafo = GammaCorrectionTransform(gamma=2)
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        trafo = GammaCorrectionTransform(gamma=2)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"].pow(2)
+        self.assertTrue((outp["data"] == expected_out).all())
+
+    def test_gamma_transform_error(self):
+        with self.assertRaises(TypeError):
+            trafo = GammaCorrectionTransform(gamma=(1, 1, 1))
+
+    def test_gamma_transform_max_smaller_one(self):
+        trafo = GammaCorrectionTransform(gamma=(0, 0.9))
+        random.seed(0)
+        rand_val = random.uniform(0, 0.9)
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"].pow(rand_val)
+        self.assertTrue((outp["data"] == expected_out).all())
+
+    def test_gamma_transform_max_greater_one(self):
+        # (1., 1.) allows switching cases but forces gamma to 1.
+        trafo = GammaCorrectionTransform(gamma=(1., 1.))
+        random.seed(0)
+        for _ in range(5):
+            outp = trafo(**self.batch_dict)
+            self.assertTrue((outp["data"] == self.batch_dict["data"]).all())
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_transforms/test_kernel_transforms.py b/tests/test_transforms/test_kernel_transforms.py
index ebf40b73..623730d8 100644
--- a/tests/test_transforms/test_kernel_transforms.py
+++ b/tests/test_transforms/test_kernel_transforms.py
@@ -18,6 +18,24 @@ def setUp(self) -> None:
             "label": torch.arange(3)
         }
 
+    def test_kernel_transform_get_conv(self):
+        conv = KernelTransform.get_conv(1)
+        self.assertEqual(conv, torch.nn.functional.conv1d)
+
+        conv = KernelTransform.get_conv(2)
+        self.assertEqual(conv, torch.nn.functional.conv2d)
+
+        conv = KernelTransform.get_conv(3)
+        self.assertEqual(conv, torch.nn.functional.conv3d)
+
+        with self.assertRaises(TypeError):
+            conv = KernelTransform.get_conv(4)
+
+    def test_kernel_transform_error(self):
+        with self.assertRaises(NotImplementedError):
+            trafo = KernelTransform(in_channels=1, kernel_size=3, std=1,
+                                    dim=2, stride=1, padding=1)
+
     def test_gaussian_smoothing_transform(self):
         trafo = GaussianSmoothingTransform(in_channels=1, kernel_size=3, std=1,
                                            dim=2, stride=1, padding=1)