From e3769d1e9233fddf278e33dbfe0eb464cc62ba3a Mon Sep 17 00:00:00 2001
From: Sanchit Verma <23355869+sanchitv7@users.noreply.github.com>
Date: Tue, 13 Aug 2024 20:22:13 +0100
Subject: [PATCH] Fix affine transform errors in augmentations by using
 torchvision's public affine API (issue #251)

---
 pytorchvideo/transforms/augmentations.py | 231 ++++------------------
 1 file changed, 31 insertions(+), 200 deletions(-)

diff --git a/pytorchvideo/transforms/augmentations.py b/pytorchvideo/transforms/augmentations.py
index 403eb0a..3ced8e1 100644
--- a/pytorchvideo/transforms/augmentations.py
+++ b/pytorchvideo/transforms/augmentations.py
@@ -1,14 +1,13 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
 """Video transforms that are used for advanced augmentation methods."""
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
+import math
 import torch
 import torchvision
-from torch import Tensor
+import torchvision.transforms.functional as F_t
 from torchvision.transforms.functional import InterpolationMode
-
+from typing import Any, Callable, Dict, Optional, Tuple
 
 # Maximum global magnitude used for video augmentation.
 _AUGMENTATION_MAX_LEVEL = 10
@@ -166,11 +165,14 @@ def _shear_x(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = video.size(-2) * factor / 2
-    return affine(
+    return F_t.affine(
         video,
-        [1, factor, translation_offset, 0, 1, 0],
+        angle=0,
+        translate=[translation_offset, 0],
+        scale=1,
+        shear=[math.degrees(math.atan(factor)), 0.0],
         fill=kwargs["fill"],
-        interpolation="bilinear",
+        interpolation=InterpolationMode.BILINEAR,
     )
 
 
@@ -185,11 +187,14 @@ def _shear_y(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = video.size(-1) * factor / 2
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, 0, factor, 1, translation_offset],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[0, translation_offset],
+        scale=1,
+        shear=[0.0, math.degrees(math.atan(factor))],
+        interpolation=InterpolationMode.BILINEAR,
+        fill=kwargs["fill"],
     )
 
 
@@ -204,11 +209,14 @@ def _translate_x(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = factor * video.size(-1)
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, translation_offset, 0, 1, 0],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[translation_offset, 0],
+        scale=1,
+        shear=[0.0, 0.0],
+        interpolation=InterpolationMode.BILINEAR,
+        fill=kwargs["fill"],
     )
 
 
@@ -223,11 +231,14 @@ def _translate_y(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = factor * video.size(-2)
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, 0, 0, 1, translation_offset],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[0, translation_offset],
+        scale=1,
+        shear=[0.0, 0.0],
+        interpolation=InterpolationMode.BILINEAR,
+        fill=kwargs["fill"],
     )
 
 
@@ -480,183 +491,3 @@ def __call__(self, video: torch.Tensor) -> torch.Tensor:
             else ()
         )
         return self.transform_fn(video, *level_args, **self.transform_hparas)
-
-
-def _assert_grid_transform_inputs(
-    img: Tensor,
-    matrix: Optional[List[float]],
-    interpolation: str,
-    fill: Optional[Union[int, float, List[float]]],
-    supported_interpolation_modes: List[str],
-    coeffs: Optional[List[float]] = None,
-) -> None:
-    def get_dimensions(img: Tensor) -> List[int]:
-        channels = 1 if img.ndim == 2 else img.shape[-3]
-        height, width = img.shape[-2:]
-        return [channels, height, width]
-
-    if not (isinstance(img, torch.Tensor)):
-        raise TypeError("Input img should be Tensor")
-
-    if matrix is not None and not isinstance(matrix, list):
-        raise TypeError("Argument matrix should be a list")
-
-    if matrix is not None and len(matrix) != 6:
-        raise ValueError("Argument matrix should have 6 float values")
-
-    if coeffs is not None and len(coeffs) != 8:
-        raise ValueError("Argument coeffs should have 8 float values")
-
-    # Check fill
-    num_channels = get_dimensions(img)[0]
-    if (
-        fill is not None
-        and isinstance(fill, (tuple, list))
-        and len(fill) > 1
-        and len(fill) != num_channels
-    ):
-        msg = (
-            "The number of elements in 'fill' cannot broadcast to match the number of "
-            "channels of the image ({} != {})"
-        )
-        raise ValueError(msg.format(len(fill), num_channels))
-
-    if interpolation not in supported_interpolation_modes:
-        raise ValueError(
-            f"Interpolation mode '{interpolation}' is unsupported with Tensor input"
-        )
-
-
-def _cast_squeeze_in(
-    img: Tensor, req_dtypes: List[torch.dtype]
-) -> Tuple[Tensor, bool, bool, torch.dtype]:
-    need_squeeze = False
-    # make image NCHW
-    if img.ndim < 4:
-        img = img.unsqueeze(dim=0)
-        need_squeeze = True
-
-    out_dtype = img.dtype
-    need_cast = False
-    if out_dtype not in req_dtypes:
-        need_cast = True
-        req_dtype = req_dtypes[0]
-        img = img.to(req_dtype)
-    return img, need_cast, need_squeeze, out_dtype
-
-
-def _cast_squeeze_out(
-    img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype
-) -> Tensor:
-    if need_squeeze:
-        img = img.squeeze(dim=0)
-
-    if need_cast:
-        if out_dtype in (
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ):
-            # it is better to round before cast
-            img = torch.round(img)
-        img = img.to(out_dtype)
-
-    return img
-
-
-def _apply_grid_transform(
-    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
-) -> Tensor:
-
-    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
-
-    if img.shape[0] > 1:
-        # Apply same grid to a batch of images
-        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
-
-    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
-    if fill is not None:
-        mask = torch.ones(
-            (img.shape[0], 1, img.shape[2], img.shape[3]),
-            dtype=img.dtype,
-            device=img.device,
-        )
-        img = torch.cat((img, mask), dim=1)
-
-    img = torch.nn.functional.grid_sample(
-        img, grid, mode=mode, padding_mode="zeros", align_corners=False
-    )
-
-    # Fill with required color
-    if fill is not None:
-        mask = img[:, -1:, :, :]  # N * 1 * H * W
-        img = img[:, :-1, :, :]  # N * C * H * W
-        mask = mask.expand_as(img)
-        fill_list, len_fill = (
-            (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
-        )
-        fill_img = (
-            torch.tensor(fill_list, dtype=img.dtype, device=img.device)
-            .view(1, len_fill, 1, 1)
-            .expand_as(img)
-        )
-        if mode == "nearest":
-            mask = mask < 0.5
-            img[mask] = fill_img[mask]
-        else:  # 'bilinear'
-            img = img * mask + (1.0 - mask) * fill_img
-
-    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
-    return img
-
-
-def _gen_affine_grid(
-    theta: Tensor,
-    w: int,
-    h: int,
-    ow: int,
-    oh: int,
-) -> Tensor:
-    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
-    # AffineGridGenerator.cpp#L18
-    # Difference with AffineGridGenerator is that:
-    # 1) we normalize grid values after applying theta
-    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
-
-    d = 0.5
-    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
-    x_grid = torch.linspace(
-        -ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device
-    )
-    base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(
-        -oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device
-    ).unsqueeze_(-1)
-    base_grid[..., 1].copy_(y_grid)
-    base_grid[..., 2].fill_(1)
-
-    rescaled_theta = theta.transpose(1, 2) / torch.tensor(
-        [0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device
-    )
-    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
-    return output_grid.view(1, oh, ow, 2)
-
-
-def affine(
-    img: Tensor,
-    matrix: List[float],
-    interpolation: str = "nearest",
-    fill: Optional[Union[int, float, List[float]]] = None,
-) -> Tensor:
-    _assert_grid_transform_inputs(
-        img, matrix, interpolation, fill, ["nearest", "bilinear"]
-    )
-
-    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
-    shape = img.shape
-    # grid will be generated on the same device as theta and img
-    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
-    return _apply_grid_transform(img, grid, interpolation, fill=fill)
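
Reviewer note, not part of the patch: torchvision's affine() takes shear angles in degrees and expresses "no shear" as [0.0, 0.0], whereas the removed matrix-based helper consumed raw shear coefficients. That is why the hunks above convert a sampled coefficient with math.degrees(math.atan(factor)) and pass zeros in the translate-only transforms. A minimal sanity check of the conversion, assuming torchvision >= 0.9 (the InterpolationMode API) and pytorchvideo's (T, C, H, W) clip layout; the tensor values are hypothetical:

    import math

    import torch
    import torchvision.transforms.functional as F_t
    from torchvision.transforms.functional import InterpolationMode

    factor = 0.3  # sampled shear coefficient; tan(shear_angle) == factor
    shear_deg = math.degrees(math.atan(factor))

    video = torch.zeros(8, 3, 32, 32)  # hypothetical (T, C, H, W) clip
    video[..., 12:20, 12:20] = 1.0  # centred square, so the shear is visible

    out = F_t.affine(
        video,
        angle=0,
        translate=[video.size(-2) * factor / 2, 0],  # same recentering offset as _shear_x
        scale=1.0,
        shear=[shear_deg, 0.0],
        interpolation=InterpolationMode.BILINEAR,
        fill=[0.0],
    )
    assert out.shape == video.shape  # geometry preserved, content slanted along x

Two caveats worth checking before merge: torchvision's translate follows the forward-mapping sign convention while the deleted helper applied the matrix as an inverse sampling map, so positive factors now shift in the opposite direction (harmless here, since _increasing_randomly_negate_to_arg negates factors at random); and torchvision already shears about the frame centre, so the recentering translate carried over from the old matrix code may no longer be necessary.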
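
An end-to-end smoke test through AugmentTransform is also cheap to run. The "ShearX" name assumes this file's default _NAME_TO_TRANSFORM_FUNC registry maps it to _shear_x, as in upstream pytorchvideo; prob=1.0 forces the transform to fire:

    import torch

    from pytorchvideo.transforms.augmentations import AugmentTransform

    video = torch.rand(8, 3, 64, 64)  # (T, C, H, W), float values in [0, 1]

    shear = AugmentTransform(transform_name="ShearX", magnitude=5, prob=1.0)
    out = shear(video)
    assert out.shape == video.shape and out.dtype == video.dtype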