Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add padding direction #2121

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 56 additions & 24 deletions test/torchtext_unittest/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,42 +218,74 @@ def _pad_transform(self, test_scripting):

input_1d_tensor = torch.ones(5)
input_2d_tensor = torch.ones((8, 5))
pad_long = transforms.PadTransform(max_length=7, pad_value=0)
pad_long_end = PadTransform(max_length=7, pad_value=0, begin=False)
pad_long_begin = PadTransform(max_length=7, pad_value=0, begin=True)
if test_scripting:
pad_long = torch.jit.script(pad_long)
padded_1d_tensor_actual = pad_long(input_1d_tensor)
padded_1d_tensor_expected = torch.cat([torch.ones(5), torch.zeros(2)])
pad_long_end = torch.jit.script(pad_long_end)
pad_long_begin = torch.jit.script(pad_long_begin)
padded_1d_tensor_actual_end = pad_long_end(input_1d_tensor)
padded_1d_tensor_expected_end = torch.cat([torch.ones(5), torch.zeros(2)])
torch.testing.assert_close(
padded_1d_tensor_actual,
padded_1d_tensor_expected,
msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
padded_1d_tensor_actual_end,
padded_1d_tensor_expected_end,
msg=f"actual: {padded_1d_tensor_actual_end}, expected: {padded_1d_tensor_expected_end}",
)
padded_1d_tensor_actual_begin = pad_long_begin(input_1d_tensor)
padded_1d_tensor_expected_begin = torch.cat([torch.zeros(2), torch.ones(5)])
torch.testing.assert_close(
padded_1d_tensor_actual_begin,
padded_1d_tensor_expected_begin,
msg=f"actual: {padded_1d_tensor_actual_begin}, expected: {padded_1d_tensor_expected_begin}",
)

padded_2d_tensor_actual = pad_long(input_2d_tensor)
padded_2d_tensor_expected = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
padded_2d_tensor_actual_end = pad_long_end(input_2d_tensor)
padded_2d_tensor_expected_end = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
torch.testing.assert_close(
padded_2d_tensor_actual_end,
padded_2d_tensor_expected_end,
msg=f"actual: {padded_2d_tensor_actual_end}, expected: {padded_2d_tensor_expected_end}",
)
padded_2d_tensor_actual_begin = pad_long_begin(input_2d_tensor)
padded_2d_tensor_expected_begin = torch.cat([torch.zeros(8, 2), torch.ones(8, 5),], axis=-1)
torch.testing.assert_close(
padded_2d_tensor_actual,
padded_2d_tensor_expected,
msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
padded_2d_tensor_actual_begin,
padded_2d_tensor_expected_begin,
msg=f"actual: {padded_2d_tensor_actual_begin}, expected: {padded_2d_tensor_expected_begin}",
)

pad_short = transforms.PadTransform(max_length=3, pad_value=0)
pad_short_end = PadTransform(max_length=3, pad_value=0)
pad_short_begin = PadTransform(max_length=3, pad_value=0, begin=True)
if test_scripting:
pad_short = torch.jit.script(pad_short)
padded_1d_tensor_actual = pad_short(input_1d_tensor)
padded_1d_tensor_expected = input_1d_tensor
pad_short_end = torch.jit.script(pad_short_end)
pad_short_begin = torch.jit.script(pad_short_begin)
padded_1d_tensor_actual_end = pad_short_end(input_1d_tensor)
padded_1d_tensor_expected_end = input_1d_tensor
torch.testing.assert_close(
padded_1d_tensor_actual,
padded_1d_tensor_expected,
msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
padded_1d_tensor_actual_end,
padded_1d_tensor_expected_end,
msg=f"actual: {padded_1d_tensor_actual_end}, expected: {padded_1d_tensor_expected_end}",
)
padded_1d_tensor_actual_begin = pad_short_begin(input_1d_tensor)
padded_1d_tensor_expected_begin = input_1d_tensor
torch.testing.assert_close(
padded_1d_tensor_actual_begin,
padded_1d_tensor_expected_begin,
msg=f"actual: {padded_1d_tensor_actual_begin}, expected: {padded_1d_tensor_expected_begin}",
)

padded_2d_tensor_actual = pad_short(input_2d_tensor)
padded_2d_tensor_expected = input_2d_tensor
padded_2d_tensor_actual_end = pad_short_end(input_2d_tensor)
padded_2d_tensor_expected_end = input_2d_tensor
torch.testing.assert_close(
padded_2d_tensor_actual_end,
padded_2d_tensor_expected_end,
msg=f"actual: {padded_2d_tensor_actual_end}, expected: {padded_2d_tensor_expected_end}",
)
padded_2d_tensor_actual_begin = pad_short_begin(input_2d_tensor)
padded_2d_tensor_expected_begin = input_2d_tensor
torch.testing.assert_close(
padded_2d_tensor_actual,
padded_2d_tensor_expected,
msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
padded_2d_tensor_actual_begin,
padded_2d_tensor_expected_begin,
msg=f"actual: {padded_2d_tensor_actual_begin}, expected: {padded_2d_tensor_expected_begin}",
)

def test_pad_transform(self) -> None:
Expand Down
12 changes: 9 additions & 3 deletions torchtext/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,16 @@ class PadTransform(Module):
:param max_length: Maximum length to pad to
:type max_length: int
:param pad_value: Value to pad the tensor with
:type pad_value: bool
:type pad_value: int
:param begin: Whether to insert pad_value at start or end, defaults to False
:type begin: bool
"""

def __init__(self, max_length: int, pad_value: int) -> None:
def __init__(self, max_length: int, pad_value: int, begin: bool = False) -> None:
super().__init__()
self.max_length = max_length
self.pad_value = float(pad_value)
self.begin = begin

def forward(self, x: Tensor) -> Tensor:
"""
Expand All @@ -255,7 +258,10 @@ def forward(self, x: Tensor) -> Tensor:
max_encoded_length = x.size(-1)
if max_encoded_length < self.max_length:
pad_amount = self.max_length - max_encoded_length
x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
if self.begin:
x = torch.nn.functional.pad(x, (pad_amount, 0), value=self.pad_value)
else:
x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
return x


Expand Down