Commit

allow variable length contexts
lucidrains committed Dec 30, 2021
1 parent f019818 commit 7928be1
Showing 3 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
@@ -244,9 +244,12 @@ seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 4).cuda() # 4 tracks
context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens

context_mask = torch.ones(4, 16).bool().cuda() # optional context mask; in this example, all context tokens are included

loss = model(
seq,
context = context,
context_mask = context_mask,
target = target
)

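The point of the commit is that the per-track contexts no longer need to share one fixed length: shorter contexts can be zero-padded to a common length and the padded positions switched off in context_mask. Below is a minimal sketch of that preprocessing, not part of the diff; pad_contexts is a hypothetical helper, and the shapes follow the README example above.

import torch

def pad_contexts(contexts):
    # pad a list of (length, dim) context tensors to a common length
    # and build the matching boolean mask (True = real token, False = padding)
    max_len = max(c.shape[0] for c in contexts)
    dim = contexts[0].shape[-1]
    padded = torch.zeros(len(contexts), max_len, dim)
    mask = torch.zeros(len(contexts), max_len, dtype = torch.bool)
    for i, c in enumerate(contexts):
        padded[i, :c.shape[0]] = c
        mask[i, :c.shape[0]] = True
    return padded, mask

contexts = [torch.randn(16, 1024), torch.randn(12, 1024), torch.randn(9, 1024), torch.randn(16, 1024)]
context, context_mask = pad_contexts(contexts) # (4, 16, 1024) and (4, 16)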
8 changes: 8 additions & 0 deletions enformer_pytorch/finetune.py
@@ -149,6 +149,7 @@ def forward(
seq,
*,
context,
context_mask = None,
target = None,
freeze_enformer = False
):
@@ -183,6 +184,13 @@ def forward(
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale

# masking

if exists(context_mask):
context_mask = F.pad(context_mask, (1, 0), value = True)
context_mask = rearrange(context_mask, 'b j -> b 1 1 1 j')
sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim = -1)
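To see why masked positions contribute nothing, here is a standalone sketch of the masked-softmax pattern used above, with made-up shapes rather than the module's real ones; in the module the rearrange to 'b 1 1 1 j' additionally broadcasts the mask over heads and query positions. The left-pad of the mask with True is assumed here to cover an extra leading context position that is always attendable.

import torch
import torch.nn.functional as F

mask = torch.tensor([True, True, False, False])  # last two context positions are padding
mask = F.pad(mask, (1, 0), value = True)         # assumed extra leading position, always unmasked
sim = torch.randn(mask.shape[-1])                # similarity logits for a single query
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)                     # masked positions receive ~0 attention weight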
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.1.15',
version = '0.1.16',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
