add Alibi positional embeddings #462
base: main
Conversation
unit test failure is not related.
test failure is not related - appears to be a rounding issue:
…w2020/PyTorch_MultiModal into alibi_positional_embeddings
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #462      +/-   ##
==========================================
+ Coverage   74.01%   74.13%   +0.12%
==========================================
  Files         207      207
  Lines       14203    14274      +71
==========================================
+ Hits        10512    10582      +70
- Misses       3691     3692       +1

☔ View full report in Codecov by Sentry.
return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]

@classmethod
def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
Fwiw there is also the get_causal_attention_mask utility (you may even be able to use get_extended_attention_mask from the same file in lieu of the repeat; it does broadcast to an extra dim for batch size, though)
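For readers outside the review thread, a minimal sketch of what a causal-mask utility plus per-head repeat does. The torchmultimodal helper names above are the reviewer's and their exact signatures are not shown here, so this standalone version assumes only plain PyTorch:

```python
import torch

def causal_attention_mask(seq_len: int, num_heads: int) -> torch.Tensor:
    # Lower-triangular boolean mask: position i may attend to positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Repeat per head so it lines up with (num_heads, seq_len, seq_len) attention scores.
    return causal.unsqueeze(0).repeat(num_heads, 1, 1)

print(causal_attention_mask(seq_len=4, num_heads=2).shape)  # torch.Size([2, 4, 4])
```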
max_seq_len: int,
num_heads: int,
) -> None:
"""recommended usage: create alibi mask before transformer block loop and integrate
Yeah, this is a bit tricky. Kinda similar to RoPE embeddings: integrating this properly will necessitate rethinking some aspects of our transformer implementation. For instance, it seems like one assumption here is that our transformer's mask should be float dtype and not bool
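To make the dtype point concrete, a small sketch (plain PyTorch, nothing repo-specific) contrasting a bool mask applied via masked_fill with the float additive mask ALiBi needs, since ALiBi folds its per-head biases directly into the values added to the scores:

```python
import torch

seq_len, num_heads = 4, 2
scores = torch.randn(num_heads, seq_len, seq_len)  # stand-in for QK^T / sqrt(d_k)

# Bool convention: True = may attend; masked positions filled with -inf.
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~bool_mask, float("-inf"))

# Float convention: 0 where attention is allowed, -inf where it is not.
# ALiBi's per-head linear biases live in this additive mask, so the
# transformer has to accept a float mask it can simply add to the scores.
float_mask = torch.zeros(seq_len, seq_len).masked_fill(~bool_mask, float("-inf"))
additive_scores = scores + float_mask

assert torch.equal(masked_scores, additive_scores)
```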
@@ -169,3 +170,108 @@ def forward(self, t: Tensor) -> Tensor:
if self.embed_dim % 2 == 1:
    embeddings = nn.functional.pad(embeddings, (0, 1))
return embeddings

class AlibiPositionEmbeddings(nn.Module): |
High-level q: if we're not using model forward and mostly using class/static methods, why not just define this as a function? Offhand I don't see a reason why this needs to be stateful (it's very possible I'm missing something though)
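To illustrate the question, a hypothetical stateless version. The name build_alibi_mask and the exact sign/layout convention below are assumptions for illustration, not the PR's implementation; slopes use the power-of-2 geometric-sequence formula from the ALiBi paper:

```python
import torch

def build_alibi_mask(max_seq_len: int, num_heads: int) -> torch.Tensor:
    """Hypothetical stateless alternative: a (num_heads, seq, seq) float mask
    combining the causal -inf mask with ALiBi's per-head linear biases."""
    # Slopes: geometric sequence starting at 2^(-8/n) (power-of-2 head counts assumed here).
    slopes = torch.tensor([2 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
    pos = torch.arange(max_seq_len)
    distances = pos[None, :] - pos[:, None]      # (seq, seq), <= 0 on/below the diagonal
    biases = slopes[:, None, None] * distances   # (heads, seq, seq)
    causal = torch.triu(torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1)
    return biases + causal

mask = build_alibi_mask(max_seq_len=8, num_heads=4)
print(mask.shape)  # torch.Size([4, 8, 8])
```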
@staticmethod
def get_slopes(num_heads: int) -> List[float]:
    """for n heads, a range from (0,1) and is the geometric sequence
    that starts at 2^(-8/n) and uses this same value as its ratio
Thank you for explaining/documenting the magic numbers 🙂
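As a concrete illustration of those numbers (a standalone re-derivation mirroring the docstring, not the PR's code):

```python
from typing import List

def get_slopes_power_of_2(num_heads: int) -> List[float]:
    # Geometric sequence starting at 2^(-8/n), with 2^(-8/n) also as the ratio.
    start = 2 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

print(get_slopes_power_of_2(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```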
    return get_slopes_power_of_2(num_heads)

# paper authors note that they only trained models that have 2^a heads for some a.
# This has beneficial properties related to input being power of 2.
Do you know what these properties are? Tbh I am confused by this, because even if n is a power of 2, some of the ratios will not be rational for n > 8
b = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
    : num_heads - closest_power_of_2
]
return [x for pair in zip(b, a) for x in pair] + a[len(b) :]
Imo this is hard to parse. Agree with @daviswer's comment about returning values in order, but could we just do sorted(a+b)? (Maybe I'm missing a tricky case... if so, a comment explaining this would suffice instead)
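A quick worked check of what the interleave produces for num_heads = 6, with slopes re-derived from the geometric-sequence docstring above. At least in this case it matches sorted(a + b, reverse=True), i.e. descending order:

```python
num_heads = 6
closest_power_of_2 = 4  # 2 ** floor(log2(6))

# Slopes for the nearest power-of-2 head count.
a = [2 ** (-8.0 / 4 * (i + 1)) for i in range(4)]
# [0.25, 0.0625, 0.015625, 0.00390625]

# Extra slopes: every other slope of the 8-head sequence, truncated to the remaining heads.
b = [2 ** (-8.0 / 8 * (i + 1)) for i in range(8)][0::2][: num_heads - closest_power_of_2]
# [0.5, 0.125]

interleaved = [x for pair in zip(b, a) for x in pair] + a[len(b):]
# [0.5, 0.25, 0.125, 0.0625, 0.015625, 0.00390625]
assert interleaved == sorted(a + b, reverse=True)
```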
# Closest power of 2 below is workaround for when num of heads is not power of 2 |
Their method of interpolating is a bit unusual. Maybe explicitly explain that for
Summary:
This PR adds an Alibi positional embeddings class (per the ALiBi paper: https://arxiv.org/abs/2108.12409).
It generates the ALiBi attention mask, which is added to the attention scores after QK^T / sqrt(d_k), and replaces the usual sinusoidal-type positional embeddings.
The class is designed to be instantiated outside the transformer block loop based on max_seq_length; the layers then retrieve the attention mask for the current sequence length, so only a single mask buffer needs to be created.
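A rough sketch of that flow. Shapes and the zero-filled stand-in mask below are illustrative only; the real buffer would be the precomputed ALiBi mask built once before the block loop:

```python
import torch
import torch.nn.functional as F

max_seq_len, num_heads, head_dim = 512, 8, 64

# Built once, outside the transformer block loop. Zeros here are a stand-in
# for the real precomputed ALiBi mask (per-head linear biases + causal -inf).
alibi_mask = torch.zeros(num_heads, max_seq_len, max_seq_len)

# Inside each attention layer: slice to the current sequence length and
# add to the raw attention scores after QK^T / sqrt(d_k).
curr_seq_len = 128
q = torch.randn(2, num_heads, curr_seq_len, head_dim)
k = torch.randn(2, num_heads, curr_seq_len, head_dim)
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
scores = scores + alibi_mask[..., :curr_seq_len, :curr_seq_len]
attn_weights = F.softmax(scores, dim=-1)
```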
Test plan:
I tested by running a 200M GPT-2 model on 10% of OpenWebText to compare curves between learned embeddings (the default in GPT-2) and ALiBi.
I also added a unit test with three checks:
a - shape of the alibi mask
b - verify first head row entry
c - verify last head last row entry
Note that half the mask is -inf, and in trying to use allclose with -inf the values will not match, so I targeted entries that contain only real (finite) numbers.
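One pattern for that kind of check, shown here with toy tensors rather than the PR's actual mask: compare the -inf layout and the finite entries separately, so the numeric comparison never touches infinities.

```python
import torch

# Toy masks that should agree: the upper triangle is -inf (causal part), the rest finite.
actual = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
expected = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)

finite = torch.isfinite(expected)
assert torch.equal(torch.isfinite(actual), finite)       # same -inf pattern
assert torch.allclose(actual[finite], expected[finite])  # values compared only where finite
```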