push coverage to 100% and small refactoring
mibaumgartner committed Nov 24, 2019
1 parent c90769b commit a0655dc
Showing 7 changed files with 193 additions and 63 deletions.
11 changes: 3 additions & 8 deletions rising/transforms/functional/intensity.py
@@ -171,12 +171,7 @@ def add_noise(data: torch.Tensor, noise_type: str, out: torch.Tensor = None,
    noise_type = noise_type + '_'
    noise_tensor = torch.empty_like(data, requires_grad=False)
    getattr(noise_tensor, noise_type)(**kwargs)

-    if out is None:
-        return data + noise_tensor
-    else:
-        out = data + noise_tensor
-        return out
+    return torch.add(data, noise_tensor, out=out)


def gamma_correction(data: torch.Tensor, gamma: float) -> torch.Tensor:
@@ -219,7 +214,7 @@ def add_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
    torch.Tensor
        augmented data
    """
-    return torch.add(data, value, out)
+    return torch.add(data, value, out=out)


def scale_by_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
@@ -242,4 +237,4 @@ def scale_by_value(data: torch.Tensor, value: float, out: torch.Tensor = None) -> torch.Tensor:
    torch.Tensor
        augmented data
    """
-    return torch.mul(data, value, out)
+    return torch.mul(data, value, out=out)
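All three fixes above hinge on the same semantics: `torch.add` and `torch.mul` allocate and return a fresh tensor when `out` is `None`, so the explicit `if out is None` branch was redundant, and `out` must be passed by keyword — a positional third argument, as in the old `torch.add(data, value, out)`, is at best rejected and at worst silently bound to a deprecated positional overload. A minimal sketch of the keyword behaviour (illustrative shapes, not part of the commit):

    import torch

    data = torch.ones(2, 2)

    # out=None (the default): a new result tensor is allocated and returned.
    res = torch.add(data, 2.0)
    assert res is not data

    # A preallocated buffer passed by keyword receives the result in place.
    buf = torch.empty_like(data)
    res = torch.add(data, 2.0, out=buf)
    assert res is buf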
107 changes: 59 additions & 48 deletions rising/transforms/intensity.py
@@ -201,7 +201,8 @@ def __init__(self, gamma: Union[float, Sequence] = (0.5, 2),
        kwargs:
            keyword arguments passed to superclass
        """
-        super().__init__(augment_fn=gamma_correction, keys=keys, grad=grad, **kwargs)
+        super().__init__(augment_fn=gamma_correction, keys=keys, grad=grad)
+        self.kwargs = kwargs
        self.gamma = gamma
        if not check_scalar(self.gamma):
            if not len(self.gamma) == 2:
@@ -222,26 +223,28 @@ def forward(self, **data) -> dict:
        dict
            dict with augmented data
        """
-        for _key in self.keys:
-            if check_scalar(self.gamma):
-                _gamma = self.gamma
-            elif self.gamma[1] < 1:
-                _gamma = random.uniform(self.gamma[0], self.gamma[1])
-            else:
-                if random.random() < 0.5:
-                    _gamma = random.uniform(self.gamma[0], 1)
-                else:
-                    _gamma = random.uniform(1, self.gamma[1])
+        if check_scalar(self.gamma):
+            _gamma = self.gamma
+        elif self.gamma[1] < 1:
+            _gamma = random.uniform(self.gamma[0], self.gamma[1])
+        else:
+            if random.random() < 0.5:
+                _gamma = random.uniform(self.gamma[0], 1)
+            else:
+                _gamma = random.uniform(1, self.gamma[1])
+
+        for _key in self.keys:
            data[_key] = self.augment_fn(data[_key], _gamma, **self.kwargs)
        return data


-class RandomAddValue(PerChannelTransform):
-    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
-                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
+class RandomValuePerChannelTransform(PerChannelTransform):
+    def __init__(self, augment_fn: callable, random_mode: str, random_kwargs: dict = None,
+                 per_channel: bool = False, keys: Sequence = ('data',),
+                 grad: bool = False, **kwargs):
        """
-        Increase values additively
+        Apply augmentations which take a random value as input via the
+        keyword argument :param:`value`

        Parameters
        ----------
@@ -259,27 +262,39 @@ def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
        kwargs:
            keyword arguments passed to augment_fn
        """
-        super().__init__(augment_fn=add_value, per_channel=per_channel,
+        super().__init__(augment_fn=augment_fn, per_channel=per_channel,
                         keys=keys, grad=grad, **kwargs)
        self.random_mode = random_mode
        self.random_kwargs = {} if random_kwargs is None else random_kwargs

    def forward(self, **data) -> dict:
        """
-        Apply transformation
+        Perform augmentation.

        Parameters
        ----------
        data: dict
-            dict with tensors
+            dict with data

        Returns
        -------
        dict
-            dict with augmented data
-        """
-        self.kwargs["value"] = self.random_fn(**self.random_kwargs)
-        return super().forward(**data)
+            augmented data
+        """
+        if self.per_channel:
+            random_seed = random.random()
+            for _key in self.keys:
+                random.seed(random_seed)
+                out = torch.empty_like(data[_key])
+                for _i in range(data[_key].shape[1]):
+                    rand_value = self.random_fn(**self.random_kwargs)
+                    out[:, _i] = self.augment_fn(data[_key][:, _i], value=rand_value,
+                                                 out=out[:, _i], **self.kwargs)
+                data[_key] = out
+            return data
+        else:
+            self.kwargs["value"] = self.random_fn(**self.random_kwargs)
+            return super().forward(**data)

    @property
    def random_mode(self) -> str:
@@ -308,11 +323,11 @@ def random_mode(self, mode) -> None:
        self.random_fn = getattr(random, mode)


-class RandomScaleValue(PerChannelTransform):
-    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
+class RandomAddValue(RandomValuePerChannelTransform):
+    def __init__(self, random_mode: str, random_kwargs: dict = None, per_channel: bool = False,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
-        Scale values
+        Increase values additively

        Parameters
        ----------
@@ -330,37 +345,33 @@ def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
        kwargs:
            keyword arguments passed to augment_fn
        """
-        super().__init__(augment_fn=scale_by_value, per_channel=per_channel,
+        super().__init__(augment_fn=add_value, random_mode=random_mode,
+                         random_kwargs=random_kwargs, per_channel=per_channel,
                         keys=keys, grad=grad, **kwargs)
-        self.random_mode = random_mode
-        self.random_kwargs = {} if random_kwargs is None else random_kwargs
-
-    def forward(self, **data) -> dict:
-        self.kwargs["value"] = self.random_fn(**self.random_kwargs)
-        return super().forward(**data)
-
-    @property
-    def random_mode(self) -> str:
-        """
-        Get random mode
-
-        Returns
-        -------
-        str
-            random mode
-        """
-        return self._random_mode
-
-    @random_mode.setter
-    def random_mode(self, mode) -> None:
-        """
-        Set random mode
-
-        Parameters
-        ----------
-        mode: str
-            specifies distribution which should be used to sample additive value (supports all
-            random generators from python random package)
-        """
-        self._random_mode = mode
-        self.random_fn = getattr(random, mode)
+
+
+class RandomScaleValue(RandomValuePerChannelTransform):
+    def __init__(self, random_mode, random_kwargs: dict = None, per_channel: bool = False,
+                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
+        """
+        Scale values
+
+        Parameters
+        ----------
+        random_mode: str
+            specifies distribution which should be used to sample additive value (supports all
+            random generators from python random package)
+        random_kwargs: dict
+            additional arguments for random function
+        per_channel: bool
+            enable transformation per channel
+        keys: Sequence
+            keys which should be augmented
+        grad: bool
+            enable gradient computation inside transformation
+        kwargs:
+            keyword arguments passed to augment_fn
+        """
+        super().__init__(augment_fn=scale_by_value, random_mode=random_mode,
+                         random_kwargs=random_kwargs, per_channel=per_channel,
+                         keys=keys, grad=grad, **kwargs)
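The net effect in this file: the duplicated `random_mode` plumbing and `forward` logic now live once in `RandomValuePerChannelTransform`, and `RandomAddValue` / `RandomScaleValue` reduce to choosing an `augment_fn`. A usage sketch against the API as it appears in this diff (shapes and keyword values are illustrative assumptions):

    import torch
    from rising.transforms.intensity import RandomAddValue, RandomScaleValue

    batch = {"data": torch.rand(2, 3, 16, 16)}  # (batch, channel, h, w)

    # One additive offset is drawn per call and shared by all channels.
    trafo = RandomAddValue(random_mode="random")
    out = trafo(**batch)

    # per_channel=True draws a separate factor for every channel index,
    # here from random.uniform(a=0.9, b=1.1).
    trafo = RandomScaleValue(random_mode="uniform",
                             random_kwargs={"a": 0.9, "b": 1.1},
                             per_channel=True)
    out = trafo(**batch)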
15 changes: 8 additions & 7 deletions rising/transforms/kernel.py
@@ -1,6 +1,6 @@
import math
import torch
-from typing import Sequence, Union
+from typing import Sequence, Union, Callable

from .abstract import AbstractTransform
from rising.utils import check_scalar
@@ -61,9 +61,10 @@ def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], dim: int
        kernel = self.create_kernel()
        self.register_buffer('weight', kernel)
        self.groups = in_channels
-        self.set_conv(dim)
+        self.conv = self.get_conv(dim)

-    def set_conv(self, dim) -> None:
+    @staticmethod
+    def get_conv(dim) -> Callable:
"""
Select convolution with regard to dimension
Expand All @@ -73,13 +74,13 @@ def set_conv(self, dim) -> None:
spatial dimension of data
"""
if dim == 1:
self.conv = torch.nn.functional.conv1d
return torch.nn.functional.conv1d
elif dim == 2:
self.conv = torch.nn.functional.conv2d
return torch.nn.functional.conv2d
elif dim == 3:
self.conv = torch.nn.functional.conv3d
return torch.nn.functional.conv3d
else:
raise RuntimeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
raise TypeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))

    def create_kernel(self) -> torch.Tensor:
        """
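Returning the convolution function from a static `get_conv` (instead of mutating `self` inside `set_conv`) makes the dispatch testable in isolation, which the new kernel tests below exploit. A dict-based sketch of the same dispatch, for illustration only:

    import torch.nn.functional as F

    def get_conv(dim: int):
        # Resolve the convolution matching the spatial dimensionality once,
        # at construction time, instead of branching on every forward pass.
        convs = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if dim not in convs:
            raise TypeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
        return convs[dim]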
24 changes: 24 additions & 0 deletions tests/test_transforms/test_abstract_transform.py
@@ -78,6 +78,30 @@ def augment_fn(inp, *args, **kwargs):
                 call(torch.tensor([2])), ]
        mock.assert_has_calls(calls)

+    def test_per_channel_transform_per_channel_true(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = PerChannelTransform(augment_fn, per_channel=True, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        calls = [call(torch.tensor([0])), call(torch.tensor([1])),
+                 call(torch.tensor([2])), ]
+        mock.assert_has_calls(calls)
+
+    def test_per_channel_transform_per_channel_false(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = PerChannelTransform(augment_fn, per_channel=False, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        mock.assert_called_once()

    def test_random_dims_transform(self):
        torch.manual_seed(0)
        self.batch_dict["data"] = torch.rand(1, 1, 32, 16)
12 changes: 12 additions & 0 deletions tests/test_transforms/test_functional/test_intensity.py
@@ -80,6 +80,12 @@ def test_add_noise(self):
        diff = (outp - self.batch_2d).abs().mean()
        self.assertTrue(diff > 50)

+    def test_gamma_correction(self):
+        outp = gamma_correction(self.batch_2d, 2)
+        self.assertTrue((self.batch_2d.pow(2) == outp).all())
+
+    def test_add_value(self):
+        outp = add_value(self.batch_2d, 2)
+        self.assertTrue((torch.add(self.batch_2d, 2) == outp).all())
+
+    def test_scale_by_value(self):
+        outp = scale_by_value(self.batch_2d, 2)
+        self.assertTrue((torch.mul(self.batch_2d, 2) == outp).all())


if __name__ == '__main__':
    unittest.main()
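For context on the functions these tests now cover: `add_noise` appends an underscore to `noise_type` and looks the result up on an empty tensor, so any in-place random initializer of `torch.Tensor` (e.g. `'normal'` resolving to `Tensor.normal_`) can be used. A minimal sketch based on the signature shown in this commit:

    import torch
    from rising.transforms.functional.intensity import add_noise

    data = torch.zeros(1, 3, 8, 8)

    # 'normal' becomes 'normal_', so the keyword arguments (mean, std)
    # are forwarded straight to Tensor.normal_.
    noisy = add_noise(data, noise_type='normal', mean=0., std=1.)
    assert noisy.shape == data.shape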
69 changes: 69 additions & 0 deletions tests/test_transforms/test_intensity_transforms.py
@@ -2,6 +2,7 @@
import torch
import random
from math import isclose
+from unittest.mock import Mock, call

from tests.test_transforms import chech_data_preservation
from rising.transforms.intensity import *
@@ -91,6 +92,74 @@ def check_noise_distance(self, trafo, min_diff=50):
        comp_diff = (outp["data"] - self.batch_dict["data"]).mean().item()
        self.assertTrue(comp_diff > min_diff)

+    def test_per_channel_transform_per_channel_true(self):
+        mock = Mock(return_value=0)
+
+        def augment_fn(inp, *args, **kwargs):
+            return mock(inp)
+
+        trafo = RandomValuePerChannelTransform(
+            augment_fn, random_mode="random", per_channel=True, keys=('label',))
+        self.batch_dict["label"] = self.batch_dict["label"][None]
+        output = trafo(**self.batch_dict)
+        calls = [call(torch.tensor([0])), call(torch.tensor([1])),
+                 call(torch.tensor([2])), ]
+        mock.assert_has_calls(calls)
+
+    def test_random_add_value(self):
+        trafo = RandomAddValue("random")
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        random.seed(0)
+        rand_val = random.random()
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"] + rand_val
+        self.assertTrue((outp["data"] == expected_out).all())
+        self.assertEqual(trafo.random_mode, "random")
+
+    def test_random_scale_value(self):
+        trafo = RandomScaleValue("random")
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        random.seed(0)
+        rand_val = random.random()
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"] * rand_val
+        self.assertTrue((outp["data"] == expected_out).all())
+        self.assertEqual(trafo.random_mode, "random")
+
+    def test_gamma_transform_scalar(self):
+        trafo = GammaCorrectionTransform(gamma=2)
+        self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
+
+        trafo = GammaCorrectionTransform(gamma=2)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"].pow(2)
+        self.assertTrue((outp["data"] == expected_out).all())
+
+    def test_gamma_transform_error(self):
+        with self.assertRaises(TypeError):
+            trafo = GammaCorrectionTransform(gamma=(1, 1, 1))
+
+    def test_gamma_transform_max_smaller_one(self):
+        trafo = GammaCorrectionTransform(gamma=(0, 0.9))
+        random.seed(0)
+        rand_val = random.uniform(0, 0.9)
+        random.seed(0)
+        outp = trafo(**self.batch_dict)
+        expected_out = self.batch_dict["data"].pow(rand_val)
+        self.assertTrue((outp["data"] == expected_out).all())
+
+    def test_gamma_transform_max_greater_one(self):
+        # (1., 1.) allows switching cases but forces gamma to 1.
+        trafo = GammaCorrectionTransform(gamma=(1., 1.))
+        random.seed(0)
+        for _ in range(5):
+            outp = trafo(**self.batch_dict)
+            self.assertTrue((outp["data"] == self.batch_dict["data"]).all())


if __name__ == '__main__':
    unittest.main()
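One detail worth spelling out: in the `per_channel` branch of `RandomValuePerChannelTransform.forward`, a single `random_seed` is drawn and `random.seed(random_seed)` is called before sampling for each key, so every augmented key sees the identical sequence of per-channel values. A stdlib-only sketch of that reseeding trick:

    import random

    # Reseeding with the same value replays the same sequence, which is how
    # the transform keeps per-channel random values identical across keys.
    seed = random.random()

    random.seed(seed)
    first = [random.random() for _ in range(3)]

    random.seed(seed)
    second = [random.random() for _ in range(3)]

    assert first == second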
18 changes: 18 additions & 0 deletions tests/test_transforms/test_kernel_transforms.py
@@ -18,6 +18,24 @@ def setUp(self) -> None:
            "label": torch.arange(3)
        }

+    def test_kernel_transform_get_conv(self):
+        conv = KernelTransform.get_conv(1)
+        self.assertEqual(conv, torch.nn.functional.conv1d)
+
+        conv = KernelTransform.get_conv(2)
+        self.assertEqual(conv, torch.nn.functional.conv2d)
+
+        conv = KernelTransform.get_conv(3)
+        self.assertEqual(conv, torch.nn.functional.conv3d)
+
+        with self.assertRaises(TypeError):
+            conv = KernelTransform.get_conv(4)
+
+    def test_kernel_transform_error(self):
+        with self.assertRaises(NotImplementedError):
+            trafo = KernelTransform(in_channels=1, kernel_size=3, std=1,
+                                    dim=2, stride=1, padding=1)
+
    def test_gaussian_smoothing_transform(self):
        trafo = GaussianSmoothingTransform(in_channels=1, kernel_size=3, std=1,
                                           dim=2, stride=1, padding=1)
