add Alibi positional embeddings #462
base: main
@@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.

import itertools
-from typing import Tuple
+import math
+from typing import List, Tuple

import torch
from torch import nn, Tensor
@@ -169,3 +170,106 @@ def forward(self, t: Tensor) -> Tensor:
        if self.embed_dim % 2 == 1:
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings


class AlibiPositionEmbeddings(nn.Module):
    """Attention with Linear Biases (ALiBi)

    Softmax(q_i @ K^T + m * [-(i - 1), ..., -2, -1, 0]),
    where m is a fixed slope specific to each head

    as proposed in:
    https://arxiv.org/abs/2108.12409
    Train Short, Test Long: Attention with Linear Biases
    Enables Input Length Extrapolation

    derived from the codebase of Ofir Press (ALiBi author):
    https://github.com/ofirpress/attention_with_linear_biases

    """

    def __init__(
        self,
        max_seq_len: int,
        num_heads: int,
    ) -> None:
        """Recommended usage: create the alibi mask once, before the transformer block loop, and integrate it into each layer's attention
Review comment: Yeah this is a bit tricky. Kinda similar to RoPE embeddings: integrating this properly will necessitate rethinking some aspects of our transformer implementation. For instance, it seems like one assumption here is that our transformer's mask should be float dtype and not bool.
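For illustration only (not part of this PR), a minimal sketch of the dtype point: a boolean mask has to be converted into an additive float mask before it can be summed with the ALiBi bias. The tensor names and shapes here are hypothetical.

```python
import torch

num_heads, seq_len = 4, 8

# hypothetical boolean causal mask: True = attend, False = block
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# convert to an additive float mask: 0.0 where attending, -inf where blocked
float_mask = torch.zeros(seq_len, seq_len)
float_mask = float_mask.masked_fill(~bool_mask, float("-inf"))

# hypothetical ALiBi bias of shape (num_heads, seq_len, seq_len)
alibi_bias = torch.zeros(num_heads, seq_len, seq_len)

# the combined mask is added to the scaled attention scores, so it must be float
combined = float_mask + alibi_bias  # broadcasts over the head dimension
```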

        Alibi should be applied after the sqrt scaling of the attention values.

        Example:
            before the transformer block loop:
                from alibi_embeddings import AlibiPE
                self.alibi = AlibiPE(config.max_seq_len, config.num_heads)
            pass a reference to the alibi class to each transformer layer,
            then in the forward of the transformer layer:
                alibi_mask = self.alibi.get_attention_mask(N)  # N = seq length of this batch
                ...
                attn = q @ k.transpose(-2, -1)
                attn *= 1.0 / math.sqrt(k.size(-1))
                attn += alibi_mask

        """
        super().__init__()

        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        self.causal_mask = self.build_causal_attention_mask(
            self.max_seq_len, self.num_heads
        )
        self.alibi_mask_base = self.build_alibi_mask(self.max_seq_len, self.num_heads)
        self.decoder_mask = self.causal_mask + self.alibi_mask_base
        self.register_buffer("alibi_mask", self.decoder_mask, persistent=False)

    def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor:
        """returns the alibi mask, clipped to the current batch seq len"""
        return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]

    @classmethod
    def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
Review comment: Fwiw there is also the get_causal_attention_mask utility (you may even be able to use …
        """builds a generic causal attention mask"""
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1
        )
        attn_mask = causal_mask.repeat(num_heads, 1, 1)
        return attn_mask

    @classmethod
    def build_alibi_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
        """generate the alibi mask by computing a distance bias matrix multiplied by each head's m (slope)"""
        distance_bias_matrix = -torch.abs(
            torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1)
        )
        slope_per_head = Tensor(cls.get_slopes(num_heads)).view(-1, 1, 1)
        alibi_mask = distance_bias_matrix * slope_per_head
        return alibi_mask
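As a worked illustration (not part of the diff), here is the distance bias matrix for a hypothetical seq_len of 3 and a single head with slope m = 0.25:

```python
import torch

seq_len = 3
# -|i - j| for every query position i and key position j
distance_bias_matrix = -torch.abs(
    torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1)
)
# tensor([[ 0, -1, -2],
#         [-1,  0, -1],
#         [-2, -1,  0]])

m = 0.25  # illustrative slope for a single head
print(distance_bias_matrix * m)
# the causal mask added in __init__ sets the upper triangle to -inf anyway,
# so what remains per row is the [-(i - 1), ..., -2, -1, 0] pattern from the paper
```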

    @staticmethod
    def get_slopes(num_heads: int) -> List[float]:
        """For n heads, returns a geometric sequence in (0, 1) that starts
        at 2^(-8/n) and uses that same value as its ratio.
Review comment: Thank you for explaining/documenting the magic numbers 🙂

        example: num_heads = 4
        result: [0.25, 0.0625, 0.015625, 0.00390625]

        """

        def get_slopes_power_of_2(n: int) -> List[float]:
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)

        # paper authors note that they only trained models that have 2^a heads for some a.
        # This has beneficial properties related to the input being a power of 2.
Review comment: Do you know what these properties are? Tbh I am confused by this because even if n is a power of 2, some of the ratios will not be rational for n > 8.

        # Closest power of 2 below is workaround for when num of heads is not power of 2
Review comment: Their method of interpolating is a bit unusual. Maybe explicitly explain that for …
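Presumably the comment refers to the non-power-of-2 fallback just below. As an illustrative sketch (not from the PR), here is how the interleaving works out for a hypothetical num_heads = 6:

```python
import math
from typing import List


def get_slopes_power_of_2(n: int) -> List[float]:
    # same geometric sequence as the inner helper above
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]


num_heads = 6
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))  # 4
slopes = (
    get_slopes_power_of_2(closest_power_of_2)                  # [0.25, 0.0625, 0.015625, 0.00390625]
    + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:2]  # every other slope for 8 heads: [0.5, 0.125]
)
# result: [0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125]
```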

        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : num_heads - closest_power_of_2
            ]
        )
Review comment: High level q: if we're not using model forward and mostly using class/static methods, why not just define this as a function? Offhand I don't see a reason why this needs to be stateful (it's very possible I'm missing something though).
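For reference, a rough sketch of the purely functional variant the question suggests, mirroring the PR's class/static methods; the free-function names here (get_alibi_slopes, build_alibi_decoder_mask) are hypothetical and not part of the PR.

```python
import math
from typing import List

import torch
from torch import Tensor


def get_alibi_slopes(num_heads: int) -> List[float]:
    # same geometric-sequence logic as AlibiPositionEmbeddings.get_slopes
    def power_of_2_slopes(n: int) -> List[float]:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_2_slopes(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    return power_of_2_slopes(closest) + power_of_2_slopes(2 * closest)[0::2][: num_heads - closest]


def build_alibi_decoder_mask(max_seq_len: int, num_heads: int) -> Tensor:
    # causal part: -inf above the diagonal, repeated per head
    causal = torch.triu(torch.ones(max_seq_len, max_seq_len) * float("-inf"), diagonal=1)
    causal = causal.repeat(num_heads, 1, 1)
    # ALiBi part: -|i - j| scaled by each head's slope
    distance = -torch.abs(torch.arange(max_seq_len) - torch.arange(max_seq_len).view(-1, 1))
    slopes = torch.tensor(get_alibi_slopes(num_heads)).view(-1, 1, 1)
    return causal + distance * slopes


# the caller registers the buffer and clips to the batch seq len itself
mask = build_alibi_decoder_mask(max_seq_len=128, num_heads=8)
mask_for_batch = mask[..., :32, :32]
```

With this shape, the nn.Module wrapper would only be a convenience for owning the precomputed buffer; whether that is worth keeping is the reviewer's open question.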