make Encoder scriptable. #160

Open · wants to merge 1 commit into main
38 changes: 20 additions & 18 deletions source/embed.py
@@ -50,17 +50,15 @@ def buffered_read(fp, buffer_size):
     if len(buffer) > 0:
         yield buffer


-def buffered_arange(max):
-    if not hasattr(buffered_arange, 'buf'):
-        buffered_arange.buf = torch.LongTensor()
-    if max > buffered_arange.buf.numel():
-        torch.arange(max, out=buffered_arange.buf)
-    return buffered_arange.buf[:max]
-
-
 # TODO Do proper padding from the beginning
-def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
+@torch.jit.script
+def convert_padding_direction(src_tokens, padding_idx: int, right_to_left: bool = False, left_to_right: bool = False):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)
     if not pad_mask.any():
@@ -73,7 +71,7 @@ def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
         # already left padded
         return src_tokens
     max_len = src_tokens.size(1)
-    range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
+    range = torch.arange(max_len).type_as(src_tokens).expand_as(src_tokens)
     num_pads = pad_mask.long().sum(dim=1, keepdim=True)
     if right_to_left:
         index = torch.remainder(range - num_pads, max_len)
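
The new type hints are what make this function compile: `torch.jit.script` assumes un-annotated arguments are Tensors, so the `int` and `bool` parameters need explicit annotations. A minimal sketch of the same pattern (the function and values below are illustrative, not part of this repository):

```python
import torch

# TorchScript treats un-annotated args as Tensors, so the int/bool
# parameters carry explicit annotations, as in this PR.
@torch.jit.script
def pad_mask_demo(src_tokens, padding_idx: int, invert: bool = False):
    mask = src_tokens.eq(padding_idx)
    return ~mask if invert else mask

tokens = torch.tensor([[1, 5, 7], [1, 1, 9]])
print(pad_mask_demo(tokens, 1))  # True where tokens equal the padding index
```
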
@@ -193,6 +191,13 @@ def __init__(
         if bidirectional:
             self.output_units *= 2

+    def combine_bidir(self, outs, bsz: int):
+        return torch.cat([
+            torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
+            for i in range(self.num_layers)
+        ], dim=0)
+
+
     def forward(self, src_tokens, src_lengths):
         if self.left_pad:
             # convert left-padding to right-padding
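
Hoisting `combine_bidir` from a closure inside `forward` to a method removes a construct the TorchScript compiler does not support: function definitions nested inside `forward` that capture locals such as `bsz`. A shape-only sketch of what the method computes, with made-up dimensions:

```python
import torch

num_layers, bsz, hidden = 2, 3, 4
# A bidirectional LSTM returns states of shape (2 * num_layers, bsz, hidden),
# with forward and backward states interleaved per layer.
outs = torch.randn(2 * num_layers, bsz, hidden)

# Concatenate each layer's forward/backward states into one feature vector.
combined = torch.cat([
    torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, 2 * hidden)
    for i in range(num_layers)
], dim=0)
print(combined.shape)  # torch.Size([2, 3, 8]) == (num_layers, bsz, 2 * hidden)
```
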
@@ -211,30 +216,25 @@ def forward(self, src_tokens, src_lengths):
         x = x.transpose(0, 1)

         # pack embedded source tokens into a PackedSequence
-        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
-
+        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths)

         # apply LSTM
         if self.bidirectional:
             state_size = 2 * self.num_layers, bsz, self.hidden_size
         else:
             state_size = self.num_layers, bsz, self.hidden_size
-        h0 = x.data.new(*state_size).zero_()
-        c0 = x.data.new(*state_size).zero_()
+        h0 = torch.zeros(*state_size)
+        c0 = torch.zeros(*state_size)
         packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))

         # unpack outputs and apply dropout
         x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value)
         assert list(x.size()) == [seqlen, bsz, self.output_units]

         if self.bidirectional:
-            def combine_bidir(outs):
-                return torch.cat([
-                    torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
-                    for i in range(self.num_layers)
-                ], dim=0)
-
-            final_hiddens = combine_bidir(final_hiddens)
-            final_cells = combine_bidir(final_cells)
+            final_hiddens = self.combine_bidir(final_hiddens, bsz)
+            final_cells = self.combine_bidir(final_cells, bsz)

         encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
@@ -248,8 +248,10 @@ def combine_bidir(outs):

         return {
             'sentemb': sentemb,
-            'encoder_out': (x, final_hiddens, final_cells),
-            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
+            'encoder_out': x,
+            'final_hiddens': final_hiddens,
+            'final_cells': final_cells,
+            'encoder_padding_mask': encoder_padding_mask
         }
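
The flattened return value lines up with a TorchScript constraint: a scripted dict must map to a single value type, so a `Dict[str, Tensor]` cannot hold a tuple in one slot or `None` in another (which is why the `if ... else None` also had to go). A minimal sketch of the pattern that compiles, with illustrative names:

```python
from typing import Dict

import torch

@torch.jit.script
def encoder_output_demo(x: torch.Tensor, padding_idx: int) -> Dict[str, torch.Tensor]:
    # Every value is a Tensor; a (x, h, c) tuple or a None here would
    # fail to compile as Dict[str, Tensor].
    return {
        'encoder_out': x,
        'encoder_padding_mask': x.eq(padding_idx),
    }

out = encoder_output_demo(torch.tensor([[1, 0, 2]]), 0)
print(out['encoder_padding_mask'])
```
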

