diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py
index 618cbca38f..c0cf8171c8 100644
--- a/test/torchtext_unittest/test_transforms.py
+++ b/test/torchtext_unittest/test_transforms.py
@@ -218,42 +218,74 @@ def _pad_transform(self, test_scripting):
         input_1d_tensor = torch.ones(5)
         input_2d_tensor = torch.ones((8, 5))
-        pad_long = transforms.PadTransform(max_length=7, pad_value=0)
+        pad_long_end = PadTransform(max_length=7, pad_value=0, begin=False)
+        pad_long_begin = PadTransform(max_length=7, pad_value=0, begin=True)
         if test_scripting:
-            pad_long = torch.jit.script(pad_long)
-        padded_1d_tensor_actual = pad_long(input_1d_tensor)
-        padded_1d_tensor_expected = torch.cat([torch.ones(5), torch.zeros(2)])
+            pad_long_end = torch.jit.script(pad_long_end)
+            pad_long_begin = torch.jit.script(pad_long_begin)
+        padded_1d_tensor_actual_end = pad_long_end(input_1d_tensor)
+        padded_1d_tensor_expected_end = torch.cat([torch.ones(5), torch.zeros(2)])
         torch.testing.assert_close(
-            padded_1d_tensor_actual,
-            padded_1d_tensor_expected,
-            msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
+            padded_1d_tensor_actual_end,
+            padded_1d_tensor_expected_end,
+            msg=f"actual: {padded_1d_tensor_actual_end}, expected: {padded_1d_tensor_expected_end}",
+        )
+        padded_1d_tensor_actual_begin = pad_long_begin(input_1d_tensor)
+        padded_1d_tensor_expected_begin = torch.cat([torch.zeros(2), torch.ones(5)])
+        torch.testing.assert_close(
+            padded_1d_tensor_actual_begin,
+            padded_1d_tensor_expected_begin,
+            msg=f"actual: {padded_1d_tensor_actual_begin}, expected: {padded_1d_tensor_expected_begin}",
         )
-        padded_2d_tensor_actual = pad_long(input_2d_tensor)
-        padded_2d_tensor_expected = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
+        padded_2d_tensor_actual_end = pad_long_end(input_2d_tensor)
+        padded_2d_tensor_expected_end = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
+        torch.testing.assert_close(
+            padded_2d_tensor_actual_end,
+            padded_2d_tensor_expected_end,
+            msg=f"actual: {padded_2d_tensor_actual_end}, expected: {padded_2d_tensor_expected_end}",
+        )
+        padded_2d_tensor_actual_begin = pad_long_begin(input_2d_tensor)
+        padded_2d_tensor_expected_begin = torch.cat([torch.zeros(8, 2), torch.ones(8, 5)], axis=-1)
         torch.testing.assert_close(
-            padded_2d_tensor_actual,
-            padded_2d_tensor_expected,
-            msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
+            padded_2d_tensor_actual_begin,
+            padded_2d_tensor_expected_begin,
+            msg=f"actual: {padded_2d_tensor_actual_begin}, expected: {padded_2d_tensor_expected_begin}",
         )
-        pad_short = transforms.PadTransform(max_length=3, pad_value=0)
+        pad_short_end = PadTransform(max_length=3, pad_value=0)
+        pad_short_begin = PadTransform(max_length=3, pad_value=0, begin=True)
         if test_scripting:
-            pad_short = torch.jit.script(pad_short)
-        padded_1d_tensor_actual = pad_short(input_1d_tensor)
-        padded_1d_tensor_expected = input_1d_tensor
+            pad_short_end = torch.jit.script(pad_short_end)
+            pad_short_begin = torch.jit.script(pad_short_begin)
+        padded_1d_tensor_actual_end = pad_short_end(input_1d_tensor)
+        padded_1d_tensor_expected_end = input_1d_tensor
         torch.testing.assert_close(
-            padded_1d_tensor_actual,
-            padded_1d_tensor_expected,
-            msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
+            padded_1d_tensor_actual_end,
+            padded_1d_tensor_expected_end,
+            msg=f"actual: {padded_1d_tensor_actual_end}, expected: {padded_1d_tensor_expected_end}",
+        )
+        padded_1d_tensor_actual_begin = pad_short_begin(input_1d_tensor)
+        padded_1d_tensor_expected_begin = input_1d_tensor
+        torch.testing.assert_close(
+            padded_1d_tensor_actual_begin,
+            padded_1d_tensor_expected_begin,
+            msg=f"actual: {padded_1d_tensor_actual_begin}, expected: {padded_1d_tensor_expected_begin}",
         )
-        padded_2d_tensor_actual = pad_short(input_2d_tensor)
-        padded_2d_tensor_expected = input_2d_tensor
+        padded_2d_tensor_actual_end = pad_short_end(input_2d_tensor)
+        padded_2d_tensor_expected_end = input_2d_tensor
+        torch.testing.assert_close(
+            padded_2d_tensor_actual_end,
+            padded_2d_tensor_expected_end,
+            msg=f"actual: {padded_2d_tensor_actual_end}, expected: {padded_2d_tensor_expected_end}",
+        )
+        padded_2d_tensor_actual_begin = pad_short_begin(input_2d_tensor)
+        padded_2d_tensor_expected_begin = input_2d_tensor
         torch.testing.assert_close(
-            padded_2d_tensor_actual,
-            padded_2d_tensor_expected,
-            msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
+            padded_2d_tensor_actual_begin,
+            padded_2d_tensor_expected_begin,
+            msg=f"actual: {padded_2d_tensor_actual_begin}, expected: {padded_2d_tensor_expected_begin}",
         )

     def test_pad_transform(self) -> None:
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 4684d58080..e57f643431 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -237,13 +237,16 @@ class PadTransform(Module):
     :param max_length: Maximum length to pad to
     :type max_length: int
     :param pad_value: Value to pad the tensor with
-    :type pad_value: bool
+    :type pad_value: int
+    :param begin: Whether to insert pad_value at start or end, defaults to False
+    :type begin: bool
     """

-    def __init__(self, max_length: int, pad_value: int) -> None:
+    def __init__(self, max_length: int, pad_value: int, begin: bool = False) -> None:
         super().__init__()
         self.max_length = max_length
         self.pad_value = float(pad_value)
+        self.begin = begin

     def forward(self, x: Tensor) -> Tensor:
         """
@@ -255,7 +258,10 @@ def forward(self, x: Tensor) -> Tensor:
         max_encoded_length = x.size(-1)
         if max_encoded_length < self.max_length:
             pad_amount = self.max_length - max_encoded_length
-            x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
+            if self.begin:
+                x = torch.nn.functional.pad(x, (pad_amount, 0), value=self.pad_value)
+            else:
+                x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
         return x
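
A minimal usage sketch of the new flag (illustrative only: `begin` exists only with this
patch applied; the expected outputs mirror the assertions in the tests above):

    import torch
    from torchtext.transforms import PadTransform

    pad_end = PadTransform(max_length=7, pad_value=0)                # default: pad on the right
    pad_begin = PadTransform(max_length=7, pad_value=0, begin=True)  # pad on the left

    x = torch.ones(5)
    pad_end(x)    # tensor([1., 1., 1., 1., 1., 0., 0.])
    pad_begin(x)  # tensor([0., 0., 1., 1., 1., 1., 1.])

As in the existing behavior, inputs already at or beyond max_length are returned unchanged
regardless of `begin`.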