Fix for the error reported in issue facebookresearch#251
sanchitv7 committed Aug 13, 2024
1 parent e02c17b commit e3769d1
Showing 1 changed file with 45 additions and 215 deletions.
260 changes: 45 additions & 215 deletions pytorchvideo/transforms/augmentations.py
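The commit drops the vendored copies of torchvision's private affine helpers (deleted at the bottom of this diff) and routes the affine-based augmentations through the public torchvision.transforms.functional API, imported under the alias F_t. A minimal sketch of the new call pattern (the (T, C, H, W) clip shape and the parameter values are assumptions, not part of the commit):

    import torch
    import torchvision.transforms.functional as F_t

    video = torch.rand(8, 3, 32, 32)  # assumed (T, C, H, W) clip in [0, 1]
    out = F_t.affine(
        video,
        angle=0,
        translate=[4, 0],
        scale=1.0,
        shear=[0.0, 0.0],  # identity shear in the public API (angles in degrees)
        interpolation=F_t.InterpolationMode.BILINEAR,
        fill=[0.5, 0.5, 0.5],
    )
    assert out.shape == video.shape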
@@ -1,14 +1,12 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
 """Video transforms that are used for advanced augmentation methods."""
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
 import torch
 import torchvision
-from torch import Tensor
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms.functional as F_t
 
+from typing import Any, Callable, Dict, Optional, Tuple
 
 # Maximum global magnitude used for video augmentation.
 _AUGMENTATION_MAX_LEVEL = 10
@@ -67,7 +65,7 @@ def _rotate(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
     """
     _check_fill_arg(kwargs)
     return torchvision.transforms.functional.rotate(
-        video, factor, fill=kwargs["fill"], interpolation=InterpolationMode.BILINEAR
+        video, factor, fill=kwargs["fill"], interpolation=F_t.InterpolationMode.BILINEAR
     )


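InterpolationMode is re-exported by torchvision.transforms.functional, so qualifying it through the F_t alias resolves the name without the removed direct import. A quick check of the pattern, with an assumed (T, C, H, W) clip:

    import torch
    import torchvision.transforms.functional as F_t

    video = torch.rand(4, 3, 16, 16)  # assumed (T, C, H, W) clip
    rotated = F_t.rotate(
        video, angle=30.0, interpolation=F_t.InterpolationMode.BILINEAR, fill=[0.0]
    )
    assert rotated.shape == video.shape  # expand defaults to False, so size is kept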
@@ -166,11 +164,14 @@ def _shear_x(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = video.size(-2) * factor / 2
-    return affine(
+    return F_t.affine(
         video,
-        [1, factor, translation_offset, 0, 1, 0],
+        angle=0,
+        translate=[translation_offset, 0],
+        scale=1,
+        shear=[factor, 1],
         fill=kwargs["fill"],
-        interpolation="bilinear",
+        interpolation=F_t.InterpolationMode.BILINEAR,
     )


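The old vendored call passed a raw shear matrix [1, factor, translation_offset, 0, 1, 0], where factor is a slope; the public affine takes shear as a pair of angles in degrees. A hedged sketch of how a matrix-style slope would map onto a degree-valued argument (the conversion is illustrative and not part of the commit):

    import math

    import torch
    import torchvision.transforms.functional as F_t

    factor = 0.3  # matrix-style shear slope used by the old helper
    shear_deg = math.degrees(math.atan(factor))  # slope -> angle, illustrative only

    video = torch.rand(4, 3, 32, 32)  # assumed (T, C, H, W) clip
    out = F_t.affine(
        video,
        angle=0,
        translate=[0, 0],
        scale=1.0,
        shear=[shear_deg, 0.0],
        interpolation=F_t.InterpolationMode.BILINEAR,
    )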
@@ -185,11 +186,14 @@ def _shear_y(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = video.size(-1) * factor / 2
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, 0, factor, 1, translation_offset],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[0, translation_offset],
+        scale=1,
+        shear=[1, factor],
+        interpolation=F_t.InterpolationMode.BILINEAR,
+        fill=kwargs["fill"]
     )


@@ -204,11 +208,14 @@ def _translate_x(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = factor * video.size(-1)
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, translation_offset, 0, 1, 0],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[translation_offset, 0],
+        scale=1,
+        shear=[1, 1],
+        interpolation=F_t.InterpolationMode.BILINEAR,
+        fill=kwargs["fill"]
     )


@@ -223,11 +230,14 @@ def _translate_y(video: torch.Tensor, factor: float, **kwargs):
     """
     _check_fill_arg(kwargs)
     translation_offset = factor * video.size(-2)
-    return affine(
+    return F_t.affine(
         video,
-        [1, 0, 0, 0, 1, translation_offset],
-        fill=kwargs["fill"],
-        interpolation="bilinear",
+        angle=0,
+        translate=[0, translation_offset],
+        scale=1,
+        shear=[1, 1],
+        interpolation=F_t.InterpolationMode.BILINEAR,
+        fill=kwargs["fill"]
     )


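In the public API, translate is a pixel offset [tx, ty], which is why the factor is scaled by the matching spatial dimension first; the identity shear is [0.0, 0.0], since shear values are interpreted as angles in degrees. A minimal sketch of the vertical case (clip shape assumed):

    import torch
    import torchvision.transforms.functional as F_t

    video = torch.rand(4, 3, 24, 24)  # assumed (T, C, H, W) clip
    factor = 0.25
    ty = factor * video.size(-2)  # vertical offset in pixels, as in _translate_y
    out = F_t.affine(
        video,
        angle=0,
        translate=[0, int(ty)],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=F_t.InterpolationMode.BILINEAR,
    )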
@@ -257,7 +267,7 @@ def _increasing_magnitude_to_arg(level: int, params: Tuple[float, float]) -> flo
 
 
 def _increasing_randomly_negate_to_arg(
-    level: int, params: Tuple[float, float]
+        level: int, params: Tuple[float, float]
 ) -> Tuple[float]:
     """
     Convert level to transform magnitude. This assumes transform magnitude increases
@@ -369,16 +379,16 @@ def _decreasing_to_arg(level: int, params: Tuple[float, float]) -> Tuple[float]:
 
 class AugmentTransform:
     def __init__(
-        self,
-        transform_name: str,
-        magnitude: int = 10,
-        prob: float = 0.5,
-        name_to_transform_func: Optional[Dict[str, Callable]] = None,
-        level_to_arg: Optional[Dict[str, Callable]] = None,
-        transform_max_paras: Optional[Dict[str, Tuple]] = None,
-        transform_hparas: Optional[Dict[str, Any]] = None,
-        sampling_type: str = "gaussian",
-        sampling_hparas: Optional[Dict[str, Any]] = None,
+            self,
+            transform_name: str,
+            magnitude: int = 10,
+            prob: float = 0.5,
+            name_to_transform_func: Optional[Dict[str, Callable]] = None,
+            level_to_arg: Optional[Dict[str, Callable]] = None,
+            transform_max_paras: Optional[Dict[str, Tuple]] = None,
+            transform_hparas: Optional[Dict[str, Any]] = None,
+            sampling_type: str = "gaussian",
+            sampling_hparas: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         The AugmentTransform composes a video transform that performs augmentation
@@ -455,9 +465,9 @@ def _get_magnitude(self) -> float:
             ).item()
         elif self.sampling_hparas["sampling_data_type"] == "float":
             return (
-                torch.rand(size=(1,)).item()
-                * (self.magnitude - self.sampling_hparas["sampling_min"])
-                + self.sampling_hparas["sampling_min"]
+                    torch.rand(size=(1,)).item()
+                    * (self.magnitude - self.sampling_hparas["sampling_min"])
+                    + self.sampling_hparas["sampling_min"]
             )
         else:
             raise ValueError("sampling_data_type must be either 'int' or 'float'")
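The float branch draws the magnitude uniformly from [sampling_min, magnitude]. A worked sketch with an assumed hyperparameter value:

    import torch

    magnitude = 10.0
    sampling_min = 0.3  # assumed value of sampling_hparas["sampling_min"]
    m = torch.rand(size=(1,)).item() * (magnitude - sampling_min) + sampling_min
    assert sampling_min <= m <= magnitude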
@@ -480,183 +490,3 @@ def __call__(self, video: torch.Tensor) -> torch.Tensor:
             else ()
         )
         return self.transform_fn(video, *level_args, **self.transform_hparas)


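For context, a minimal sketch of how AugmentTransform is driven end to end, assuming the class's default sampling hyperparameters; the transform name is assumed to be a key in the module's transform registry, and the clip shape and prob value are assumptions:

    import torch

    from pytorchvideo.transforms.augmentations import AugmentTransform

    transform = AugmentTransform(transform_name="ShearX", magnitude=5, prob=1.0)
    video = torch.rand(8, 3, 64, 64)  # assumed (T, C, H, W) clip in [0, 1]
    augmented = transform(video)  # samples a magnitude, then applies ShearX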
-def _assert_grid_transform_inputs(
-    img: Tensor,
-    matrix: Optional[List[float]],
-    interpolation: str,
-    fill: Optional[Union[int, float, List[float]]],
-    supported_interpolation_modes: List[str],
-    coeffs: Optional[List[float]] = None,
-) -> None:
-    def get_dimensions(img: Tensor) -> List[int]:
-        channels = 1 if img.ndim == 2 else img.shape[-3]
-        height, width = img.shape[-2:]
-        return [channels, height, width]
-
-    if not (isinstance(img, torch.Tensor)):
-        raise TypeError("Input img should be Tensor")
-
-    if matrix is not None and not isinstance(matrix, list):
-        raise TypeError("Argument matrix should be a list")
-
-    if matrix is not None and len(matrix) != 6:
-        raise ValueError("Argument matrix should have 6 float values")
-
-    if coeffs is not None and len(coeffs) != 8:
-        raise ValueError("Argument coeffs should have 8 float values")
-
-    # Check fill
-    num_channels = get_dimensions(img)[0]
-    if (
-        fill is not None
-        and isinstance(fill, (tuple, list))
-        and len(fill) > 1
-        and len(fill) != num_channels
-    ):
-        msg = (
-            "The number of elements in 'fill' cannot broadcast to match the number of "
-            "channels of the image ({} != {})"
-        )
-        raise ValueError(msg.format(len(fill), num_channels))
-
-    if interpolation not in supported_interpolation_modes:
-        raise ValueError(
-            f"Interpolation mode '{interpolation}' is unsupported with Tensor input"
-        )


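The removed validator's fill check only fires when a multi-element fill cannot broadcast to the channel count. A small sketch against the definition above:

    import torch

    img = torch.rand(3, 8, 8)  # 3 channels
    try:
        # A 2-element fill cannot broadcast to 3 channels.
        _assert_grid_transform_inputs(
            img,
            [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            "bilinear",
            [0.5, 0.5],
            ["nearest", "bilinear"],
        )
    except ValueError as err:
        print(err)  # "... cannot broadcast ... (2 != 3)"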
-def _cast_squeeze_in(
-    img: Tensor, req_dtypes: List[torch.dtype]
-) -> Tuple[Tensor, bool, bool, torch.dtype]:
-    need_squeeze = False
-    # make image NCHW
-    if img.ndim < 4:
-        img = img.unsqueeze(dim=0)
-        need_squeeze = True
-
-    out_dtype = img.dtype
-    need_cast = False
-    if out_dtype not in req_dtypes:
-        need_cast = True
-        req_dtype = req_dtypes[0]
-        img = img.to(req_dtype)
-    return img, need_cast, need_squeeze, out_dtype
-
-
-def _cast_squeeze_out(
-    img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype
-) -> Tensor:
-    if need_squeeze:
-        img = img.squeeze(dim=0)
-
-    if need_cast:
-        if out_dtype in (
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ):
-            # it is better to round before cast
-            img = torch.round(img)
-        img = img.to(out_dtype)
-
-    return img


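These two removed helpers form a round trip: promote an image to a 4-D float tensor for grid sampling, then restore the original rank and dtype, rounding before integer casts. A minimal sketch using the definitions above:

    import torch

    img = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)  # CHW uint8
    x, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32])
    assert x.ndim == 4 and x.dtype == torch.float32
    restored = _cast_squeeze_out(x, need_cast, need_squeeze, out_dtype)
    assert restored.shape == img.shape and restored.dtype == torch.uint8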
-def _apply_grid_transform(
-    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
-) -> Tensor:
-
-    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
-
-    if img.shape[0] > 1:
-        # Apply same grid to a batch of images
-        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
-
-    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
-    if fill is not None:
-        mask = torch.ones(
-            (img.shape[0], 1, img.shape[2], img.shape[3]),
-            dtype=img.dtype,
-            device=img.device,
-        )
-        img = torch.cat((img, mask), dim=1)
-
-    img = torch.nn.functional.grid_sample(
-        img, grid, mode=mode, padding_mode="zeros", align_corners=False
-    )
-
-    # Fill with required color
-    if fill is not None:
-        mask = img[:, -1:, :, :]  # N * 1 * H * W
-        img = img[:, :-1, :, :]  # N * C * H * W
-        mask = mask.expand_as(img)
-        fill_list, len_fill = (
-            (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
-        )
-        fill_img = (
-            torch.tensor(fill_list, dtype=img.dtype, device=img.device)
-            .view(1, len_fill, 1, 1)
-            .expand_as(img)
-        )
-        if mode == "nearest":
-            mask = mask < 0.5
-            img[mask] = fill_img[mask]
-        else:  # 'bilinear'
-            img = img * mask + (1.0 - mask) * fill_img
-
-    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
-    return img


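The dummy-mask trick warps an all-ones channel together with the image; wherever the warped mask drops below one, the sample partly came from zero padding, and the 'bilinear' branch blends toward the fill color in proportion. The blend arithmetic on a single value:

    import torch

    pixel, fill_value, mask = torch.tensor(0.8), torch.tensor(0.5), torch.tensor(0.25)
    blended = pixel * mask + (1.0 - mask) * fill_value  # matches the 'bilinear' branch
    assert torch.isclose(blended, torch.tensor(0.575))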
-def _gen_affine_grid(
-    theta: Tensor,
-    w: int,
-    h: int,
-    ow: int,
-    oh: int,
-) -> Tensor:
-    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
-    # AffineGridGenerator.cpp#L18
-    # Difference with AffineGridGenerator is that:
-    # 1) we normalize grid values after applying theta
-    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
-
-    d = 0.5
-    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
-    x_grid = torch.linspace(
-        -ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device
-    )
-    base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(
-        -oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device
-    ).unsqueeze_(-1)
-    base_grid[..., 1].copy_(y_grid)
-    base_grid[..., 2].fill_(1)
-
-    rescaled_theta = theta.transpose(1, 2) / torch.tensor(
-        [0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device
-    )
-    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
-    return output_grid.view(1, oh, ow, 2)


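The removed generator builds an (N, oh, ow, 2) sampling grid for grid_sample, normalizing by the source size w, h so the output size can differ (the "extend" behavior noted in its comments). A shape sketch with an identity transform, using the definition above:

    import torch

    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine
    grid = _gen_affine_grid(theta, w=8, h=6, ow=8, oh=6)
    assert grid.shape == (1, 6, 8, 2)  # (N, oh, ow, 2)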
-def affine(
-    img: Tensor,
-    matrix: List[float],
-    interpolation: str = "nearest",
-    fill: Optional[Union[int, float, List[float]]] = None,
-) -> Tensor:
-    _assert_grid_transform_inputs(
-        img, matrix, interpolation, fill, ["nearest", "bilinear"]
-    )
-
-    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
-    shape = img.shape
-    # grid will be generated on the same device as theta and img
-    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
-    return _apply_grid_transform(img, grid, interpolation, fill=fill)
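The removed affine took a flat row-major 2x3 matrix [a, b, tx, c, d, ty] in pixel units; the callers above now express their transforms through the public API's named parameters instead. A last sketch of the old matrix convention, using the definitions above (a pure 2 px horizontal translation, values assumed):

    import torch

    video = torch.rand(1, 3, 16, 16)
    out = affine(
        video,
        [1.0, 0.0, 2.0, 0.0, 1.0, 0.0],
        interpolation="bilinear",
        fill=[0.0],
    )
    assert out.shape == video.shape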
