add Alibi positional embeddings #462

Open · wants to merge 15 commits into main
Changes from 10 commits
76 changes: 75 additions & 1 deletion tests/modules/layers/test_position_embedding.py
@@ -7,14 +7,20 @@
import pytest

import torch
from tests.test_utils import assert_expected
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.modules.layers.position_embedding import (
    AlibiPositionEmbeddings,
    BroadcastedPositionEmbedding,
    SinusoidalPositionEmbeddings,
)


@pytest.fixture(autouse=True)
def random():
    set_rng_seed(2023)


class TestBroadcastedPositionEmbedding:
    @pytest.fixture(scope="class")
    def pos_emb(self):
@@ -112,3 +118,71 @@ def test_forward(self, data, emb):
        actual = emb(data)
        expected = torch.Size([3, 5])
        assert_expected(actual.shape, expected)


class TestAlibiPositionEmbedding:
    @pytest.fixture
    def max_seq_len(self):
        return 16

    @pytest.fixture
    def embedding_dim(self):
        return 32

    @pytest.fixture
    def num_heads(self):
        return 8

    def test_alibi_mask(
        self,
        max_seq_len,
        num_heads,
    ):
        alibi_class = AlibiPositionEmbeddings(
            max_seq_len=max_seq_len, num_heads=num_heads
        )
        base_mask = alibi_class.get_attention_mask(max_seq_len)

        # verify mask shape
        expected_shape = torch.Size((num_heads, max_seq_len, max_seq_len))
        assert_expected(base_mask.shape, expected_shape)

        # verify alibi mask components
        expected_last_head_row = torch.tensor(
            [
                -0.0586,
                -0.0547,
                -0.0508,
                -0.0469,
                -0.0430,
                -0.0391,
                -0.0352,
                -0.0312,
                -0.0273,
                -0.0234,
                -0.0195,
                -0.0156,
                -0.0117,
                -0.0078,
                -0.0039,
                0.0000,
            ]
        )

        expected_first_head_first_row_first_entry = torch.tensor(
            0.0000,
        )

        assert_expected(
            base_mask[0][0][0],
            expected_first_head_first_row_first_entry,
            rtol=0,
            atol=1e-4,
        )

        assert_expected(
            base_mask[num_heads - 1][max_seq_len - 1],
            expected_last_head_row,
            rtol=0,
            atol=1e-4,
        )
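
As a sanity check on the hard-coded expectations above (an illustrative aside, not part of the PR): with 8 heads the geometric slope sequence starts at 2^(-8/8) = 0.5 and decays by the same ratio, so the last head's slope is 0.5^8 = 2^-8, and the last row of the bias is that slope times the distances [-15, ..., -1, 0]. A minimal standalone sketch:

```python
import torch

max_seq_len = 16

# Slope of the 8th of 8 heads: the sequence starts at 2 ** (-8 / 8) = 0.5 and
# decays by the same ratio, so the last slope is 0.5 ** 8 = 2 ** -8.
last_slope = 2 ** -8

# Distances from the last query position (15) to key positions 0..15.
distances = (torch.arange(max_seq_len) - (max_seq_len - 1)).float()  # [-15, ..., -1, 0]

print(last_slope * distances)
# tensor([-0.0586, -0.0547, ..., -0.0039, 0.0000]) up to rounding; matches
# expected_last_head_row, and the causal part of the mask adds nothing on the last row.
```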
106 changes: 105 additions & 1 deletion torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Tuple
import math
from typing import List, Tuple

import torch
from torch import nn, Tensor
@@ -169,3 +170,106 @@ def forward(self, t: Tensor) -> Tensor:
        if self.embed_dim % 2 == 1:
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings


class AlibiPositionEmbeddings(nn.Module):
Contributor comment: High level q: if we're not using model forward and mostly using class/static methods, why not just define this as a function? Offhand I don't see a reason why this needs to be stateful (it's very possible I'm missing something though).

"""Attention with Linear Biases (ALiBi)

# Softmax(qiKT + m · [-(i - 1), ..., -2, -1, 0]),
where m = fixed specific slope per head

as proposed in:
https://arxiv.org/abs/2108.12409
Train Short, Test Long: Attention with Linear Biases
Enables Input Length Extrapolation

derived from Ofir Press (alibi author) codebase:
https://github.com/ofirpress/attention_with_linear_biases

"""

    def __init__(
        self,
        max_seq_len: int,
        num_heads: int,
    ) -> None:
        """Recommended usage: create the ALiBi mask before the transformer block loop and integrate it there.
Contributor comment: Yeah this is a bit tricky. Kinda similar to RoPE embeddings: integrating this properly will necessitate rethinking some aspects of our transformer implementation. For instance, seems like one assumption here is that our transformer's mask should be float dtype and not bool.

        ALiBi should be applied after the sqrt scaling of the attention scores.

        Example:
            Before the transformer block loop:
                from alibi_embeddings import AlibiPE

                self.alibi = AlibiPE(config.max_seq_len, config.num_heads)

            Pass a reference to the alibi module to each transformer layer;
            then, in the forward of each transformer layer:
                alibi_mask = self.alibi.get_attention_mask(N)  # N = seq length of this batch
                ...
                attn = q @ k.transpose(-2, -1)
                attn *= 1.0 / math.sqrt(k.size(-1))
                attn += alibi_mask
        """
        super().__init__()

        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        self.causal_mask = self.build_causal_attention_mask(
            self.max_seq_len, self.num_heads
        )
        self.alibi_mask_base = self.build_alibi_mask(self.max_seq_len, self.num_heads)
        self.decoder_mask = self.causal_mask + self.alibi_mask_base
        self.register_buffer("alibi_mask", self.decoder_mask, persistent=False)

    def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor:
        """returns the alibi mask, clipped to the current batch seq len"""
        return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]

    @classmethod
    def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
Contributor comment: Fwiw there is also the get_causal_attention_mask utility (you may even be able to use get_extended_attention_mask from the same file in lieu of the repeat, it does broadcast to an extra dim for batch size though).

"""builds a generic causal attention mask"""
causal_mask = torch.triu(
torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1
)
attn_mask = causal_mask.repeat(num_heads, 1, 1)
return attn_mask

    @classmethod
    def build_alibi_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
        """generate the alibi mask by computing a distance bias matrix multiplied by each head's m (slope)"""
        distance_bias_matrix = -torch.abs(
            torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1)
        )
        slope_per_head = Tensor(cls.get_slopes(num_heads)).view(-1, 1, 1)
        alibi_mask = distance_bias_matrix * slope_per_head
        return alibi_mask
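
For intuition (an illustrative aside, not part of the diff): the two factors combined above are a symmetric distance matrix and one scalar slope per head. For a toy seq_len = 4 the distance bias looks like this:

```python
import torch

seq_len = 4
# Entry (i, j) is -|i - j|: zero on the diagonal, more negative further from it.
distance_bias = -torch.abs(torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1))
# tensor([[ 0, -1, -2, -3],
#         [-1,  0, -1, -2],
#         [-2, -1,  0, -1],
#         [-3, -2, -1,  0]])

# With num_heads = 2 the slopes are [2 ** -4, 2 ** -8], so head 0's bias is
# distance_bias / 16 and head 1's bias is distance_bias / 256; build_alibi_mask
# stacks these per-head matrices into a (num_heads, seq_len, seq_len) tensor.
```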

    @staticmethod
    def get_slopes(num_heads: int) -> List[float]:
        """For n heads, the slopes are a sequence in (0, 1): the geometric sequence
        that starts at 2^(-8/n) and uses that same value as its ratio.
Contributor comment: Thank you for explaining/documenting the magic numbers 🙂


        example: num_heads = 4
        result: [0.25, 0.0625, 0.015625, 0.00390625]

        """

        def get_slopes_power_of_2(n: int) -> List[float]:
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)

        # paper authors note that they only trained models that have 2^a heads for some a.
        # This has beneficial properties related to input being power of 2.
Contributor comment: Do you know what these properties are? Tbh I am confused by this because even if n is a power of 2 some of the ratios will not be rational for n > 8.

        # Closest power of 2 below is a workaround for when num of heads is not a power of 2
Contributor comment: Their method of interpolating is a bit unusual. Maybe explicitly explain that for $num\_heads = 2^N + k$ they are splicing the geometric series with ratio $2^{-\frac{8}{N}}$ with the first $2k$ elements of the geometric series with ratio $2^{-\frac{8}{N+1}}$ (assuming I am even understanding it correctly 😅).


        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : num_heads - closest_power_of_2
            ]
        )
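
As a usage sketch of the docstring's recommended integration (illustrative only; the batch, head, and dimension sizes are arbitrary, and the import path assumes the module this PR adds the class to), the mask is added after the 1/sqrt(d) scaling:

```python
import math

import torch
from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

batch_size, num_heads, max_seq_len, head_dim = 2, 8, 16, 4

# Build the mask once, outside the transformer block loop.
alibi = AlibiPositionEmbeddings(max_seq_len=max_seq_len, num_heads=num_heads)

# Toy per-head projections for a batch whose sequence length is shorter than max_seq_len.
seq_len = 10
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

attn = q @ k.transpose(-2, -1)             # (batch, heads, seq, seq)
attn *= 1.0 / math.sqrt(k.size(-1))        # scale first, as the docstring notes
attn += alibi.get_attention_mask(seq_len)  # (heads, seq, seq) broadcasts over batch
out = torch.softmax(attn, dim=-1) @ v      # causal + ALiBi-biased attention output
```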