add attention aggregation fine-tuning with dynamic contexts
lucidrains committed Dec 30, 2021
1 parent 80f75e9 commit 55acb63
Showing 3 changed files with 123 additions and 2 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -218,6 +218,41 @@ loss = model(
loss.backward()
```

Finally, there is also a way to use attention aggregation over a set of context embeddings (or a single context embedding per track). Simply use the `ContextAttentionAdapterWrapper`.

```python
import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper

enformer = Enformer(
dim = 1536,
depth = 1,
heads = 8,
target_length = 200,
)

model = ContextAttentionAdapterWrapper(
enformer = enformer,
context_dim = 1024,
heads = 8, # number of heads in the cross attention
dim_head = 64 # dimension per head
).cuda()

seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda() # 4 tracks
context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens

loss = model(
seq,
context = context,
target = target
)

loss.backward()
```
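
A single context embedding per track also works; the wrapper simply treats it as a context of length one. Below is a minimal sketch of that usage, reusing the `model`, sequence, and target from the example above; the 2-dimensional `context` (one 1024-dim embedding per track) is illustrative.

```python
seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda()   # 4 tracks
context = torch.randn(4, 1024).cuda()    # one context embedding per track (no token dimension)

loss = model(
    seq,
    context = context,
    target = target
)

loss.backward()
```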

## Appreciation

Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model in an acceptable amount of time
88 changes: 87 additions & 1 deletion enformer_pytorch/finetune.py
@@ -2,7 +2,9 @@
from contextlib import contextmanager
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from enformer_pytorch.enformer_pytorch import Enformer, poisson_loss

def exists(val):
@@ -34,6 +36,8 @@ def get_enformer_embeddings(model, seq, freeze = False):

# fine-tune wrapper classes

# extra head projection, akin to how human and mouse tracks were trained

class HeadAdapterWrapper(nn.Module):
def __init__(
self,
@@ -65,6 +69,9 @@ def forward(

return poisson_loss(preds, target)

# wrapper that allows one to supply each track with a context dimension
# the context embedding will be projected into the weights and biases of the head linear projection (hypernetwork)
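# (schematically, with illustrative names: each track's context vector is mapped to a per-track weight w_t and bias b_t,
#  and the prediction is roughly einsum('b n d, t d -> b n t', embeddings, w_t) + b_t)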

class ContextAdapterWrapper(nn.Module):
def __init__(
self,
@@ -100,3 +107,82 @@ def forward(
return pred

return poisson_loss(pred, target)

# wrapper that does attention aggregation of the context, which can be a set of token embeddings per track (num_tracks x context_len x context_dim) or a single embedding per track (num_tracks x context_dim)

class ContextAttentionAdapterWrapper(nn.Module):
def __init__(
self,
*,
enformer,
context_dim,
heads = 8,
dim_head = 64
):
super().__init__()
assert isinstance(enformer, Enformer)
self.enformer = enformer

self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = heads * dim_head
self.to_queries = nn.Linear(enformer.dim * 2, inner_dim)

self.null_key = nn.Parameter(torch.randn(inner_dim))
self.null_value = nn.Parameter(torch.randn(inner_dim))

self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, 1),
            Rearrange('b c ... 1 -> b ... c'),
            nn.Softplus()
        )

def forward(
self,
seq,
*,
context,
target = None,
freeze_enformer = False
):
h = self.heads
embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer)
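        # embeddings: (batch, target_length, enformer.dim * 2)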

# perform cross attention from genetic -> context

if context.ndim == 2:
context = rearrange(context, 'b d -> b 1 d')

q = self.to_queries(embeddings)
k, v = self.to_key_values(context).chunk(2, dim = -1)
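        # q: (batch, target_length, inner_dim); k, v: (num_contexts, context_len, inner_dim)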

null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value))

k = torch.cat((null_k, k), dim = 1)
v = torch.cat((null_v, v), dim = 1)
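        # a learned null key / value is prepended so each query can attend to "nothing" when the context is uninformative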

# split out head

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale
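        # sim: (batch, num_contexts, heads, target_length, context_len + 1)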

# attention

attn = sim.softmax(dim = -1)

# aggregate

        out = einsum('b c h i j, c h j d -> b c h i d', attn, v)

        out = rearrange(out, 'b c h n d -> b c n (h d)')

# combine heads and project / softplus

pred = self.to_out(out)
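        # pred: (batch, target_length, num_contexts), matching the shape of `target`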

if not exists(target):
return pred

return poisson_loss(pred, target)
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.1.12',
version = '0.1.14',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
