Add Parallel_Attention_Blocks (3 of 3) #457
base: main
Conversation
Codecov Report
Patch coverage:
Additional details and impacted files
@@ Coverage Diff @@
## main #457 +/- ##
==========================================
+ Coverage 69.78% 70.12% +0.34%
==========================================
Files 175 177 +2
Lines 11992 12156 +164
==========================================
+ Hits 8369 8525 +156
- Misses 3623 3631 +8
☔ View full report in Codecov by Sentry.
LG overall, @ebsmothers do you mind taking a pass through this as well?
# confirm num Q matches num_heads
assert_expected(num_heads, mha_parallel_attention.num_heads)

# input_ones = torch.ones(dims, dtype=torch.float)
nit: rm
fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim])

assert_expected(fixed_result_firstrow, attn_output[0][0], rtol=0, atol=1e-4)
assert_expected(fixed_output_shape, attn_output.shape)
@ebsmothers, do we / should we do any additional testing besides verifying the first row of outputs?
This plus shape should be sufficient. One nitpick would be to assert on something besides the first row (maybe the mean over an axis) just because I have seen cases where the first row is actually correct but others are not (e.g. if there is a bug in masking logic).
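As an illustration of that suggestion (a sketch only, reusing `attn_output` and `assert_expected` from the test above; `fixed_row_means` is a hypothetical precomputed expected tensor, not something in the PR):

```python
# Hypothetical extra assertion: reduce over the embedding dim so every sequence
# position contributes to the check, which would catch e.g. masking bugs that
# leave the first row correct but corrupt later rows.
actual_row_means = attn_output[0].mean(dim=-1)  # shape: (max_seq_len,)
assert_expected(fixed_row_means, actual_row_means, rtol=0, atol=1e-4)
```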
)
fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim])
assert_expected(fixed_output_shape, attn_output.shape)
# print(f"{attn_output[0][0]}")
nit: remove
rel_pos_bias: Optional[torch.Tensor] = None,
has_causal_mask: bool = False,
) -> torch.Tensor:
"""TODO: No KV cache support yet"""
shall we file an issue for this?
return 32

@pytest.fixture
def mha_parallel_attention(self, embedding_dim, num_heads, total_layers):
One thing to be careful about for these unit tests: if you are not explicitly initializing the params of the modules, then the test results will be sensitive to the order in which submodules are initialized. In the past we've seen cases where an otherwise no-op change breaks tests just because of changes in initialization order. We have the util `init_weights_with_constant` for this, but the tradeoff is that it also makes the test case a lot more trivial (since all weights are 1s).
Good point!
Currently the parallel attention blocks run their own full init automatically, so that covers this concern. That said, I should add a comment stating this assumption so that if it breaks in the future, the reader can quickly ascertain what might be going awry.
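For context, a minimal sketch of the constant-init pattern being discussed; `init_weights_with_constant_sketch` below is a stand-in written for illustration, not the actual torchmultimodal util (check the real helper's signature before relying on it):

```python
import torch
from torch import nn


def init_weights_with_constant_sketch(module: nn.Module, constant: float = 1.0) -> None:
    # Pin every parameter to a constant so expected test values do not depend
    # on the order in which submodules are constructed/initialized.
    for p in module.parameters():
        nn.init.constant_(p, constant)


layer = nn.Linear(4, 4)
init_weights_with_constant_sketch(layer)
assert torch.equal(layer.weight, torch.ones(4, 4))
assert torch.equal(layer.bias, torch.ones(4))
```

The tradeoff mentioned above still applies: with all-ones weights the forward pass exercises much less interesting numerics.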
# from position_embedding import RotaryEmbedding


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Noob question: is this more efficient than repeat_interleave? I thought in this case the extra memory would be allocated either way
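For reference, a sketch (not the PR's implementation) of the expand-based pattern commonly used for this, compared against `torch.repeat_interleave`; once the final reshape materializes the copies, the memory cost is similar either way, so any difference is mostly kernel overhead:

```python
import torch


def repeat_kv_expand(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # expand() creates a broadcasted view, but the reshape below materializes
    # the repeated heads, so peak memory ends up comparable to repeat_interleave.
    bs, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)
    )


x = torch.randn(2, 5, 4, 8)  # (bs, seq, kv_heads, head_dim) -- assumed layout
assert torch.equal(repeat_kv_expand(x, 3), torch.repeat_interleave(x, 3, dim=2))
```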
# swiglu
activated_mlp = self.mlp_activation(inner_mlp) * gate

if self.mlp_dropout.p:
Why the if statement here? Isn't it just a no-op if p=0.0 anyways?
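For what it's worth, `nn.Dropout` with `p=0.0` already leaves its input unchanged, so the guard only skips the (cheap) call itself; a quick check:

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.0)
x = torch.randn(3, 4)
# With p=0.0 nothing is zeroed and the 1/(1-p) rescale is a no-op,
# even in training mode, so the output values equal the input.
assert torch.equal(drop(x), x)
```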
q, k = self.rotary_emb(q, k, start_pos)

# group query expansion
def kv_expansion(head: torch.Tensor) -> torch.Tensor:
nit: can we either make this a method of the class or a standalone function? I feel the nested function harms readability here
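As a rough sketch of that suggestion (the method name and the `num_kv_heads` attribute are illustrative assumptions, not the PR's code), the helper could become a small method that reuses the module-level `repeat_kv` shown above:

```python
def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
    # Hypothetical method form of kv_expansion: repeat grouped KV heads so the
    # KV head count matches the number of query heads before attention.
    return repeat_kv(kv, self.num_heads // self.num_kv_heads)
```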
emb_dimension: int,
num_heads: int,
head_dimension: int = None,
mlp_expansion_ratio: float = 2.6875,  # 8/3 is param matching
nit: any reason not to just use an integer value here and skip the multiplication by ratio?
Oh, there is a reason for this actually - the idea was to mimic the same param count as when most people use mlp_expansion = 4, but keep the hidden dim within a power-of-8 regime.
A common mistake when people compare SwiGLU vs., say, GeLU is that they claim SwiGLU is slower... but what's really happening is that since SwiGLU uses a gate, you have more total params if you keep the same 4.0 multiplication factor and simply drop SwiGLU into the MLP.
Therefore, this 2.6875 gives you ~the same params as activation + mul factor 4, but within the power-of-8 regime (this was from a paper that showed this gave you a slight efficiency gain... can't remember if it was power of 8 or 16, but something like that). Hence the comment that the exact match is 8/3.
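A quick back-of-the-envelope check of that ratio (illustrative numbers, not from the PR): a gated SwiGLU-style MLP has three weight matrices (up, gate, down) versus two for a plain GeLU MLP, so matching params means shrinking the hidden size by roughly 2/3, i.e. 4 * 2/3 = 8/3 ≈ 2.667, nudged here to 2.6875 to keep the hidden dim a rounder multiple per the comment above:

```python
emb_dim = 4096                                   # example model width
gelu_params = 2 * emb_dim * (4 * emb_dim)        # up + down projections
hidden = int(emb_dim * 2.6875)                   # 11008
swiglu_params = 3 * emb_dim * hidden             # up + gate + down projections
print(gelu_params, swiglu_params, swiglu_params / gelu_params)  # ratio ~ 1.008
```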
# input_ones = torch.ones(dims, dtype=torch.float)

x = torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs =1,
attn_output = mha_parallel_attention(x)
Would also consider adding a test case covering mask and/or rel_pos_bias args
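A sketch of what that could look like (the input mirrors the existing test above and `has_causal_mask` comes from the forward signature shown earlier; the test name itself is made up):

```python
def test_mha_parallel_attention_causal_mask(
    self, mha_parallel_attention, max_seq_len, embedding_dim
):
    # Same fixed input as the other cases, but with the causal-mask path enabled.
    x = torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs = 1
    attn_output = mha_parallel_attention(x, has_causal_mask=True)
    assert_expected(torch.Size([1, max_seq_len, embedding_dim]), attn_output.shape)
```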
# input_ones = torch.ones(dims, dtype=torch.float)

x = torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs =1,
Can probably just make this a fixture (since I think it's used in all test cases)
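For example, a hedged sketch of such a fixture (the name is illustrative; it just mirrors the inline tensor used in each test case, with explicit seeding for determinism):

```python
@pytest.fixture
def input_tensor(self, max_seq_len, embedding_dim):
    torch.manual_seed(0)  # keep the fixed input deterministic across test runs
    return torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs = 1
```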
self.num_q = 1

self.in_proj_dims = [
nit: add a comment either here or on self.in_proj definition about the fusing you're doing. I think fusing of the MLP and gate could be a bit unusual for those unfamiliar with parallel attention
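For readers unfamiliar with the pattern, a self-contained sketch of the fusing being described (dimension values and names are illustrative, not the PR's attributes): Q, K, V, the MLP up-projection, and the SwiGLU gate all come out of one fused projection and are then split apart:

```python
import torch
from torch import nn

emb_dim, q_dim, kv_dim, mlp_dim = 64, 64, 16, 172      # example sizes only
in_proj_dims = [q_dim, kv_dim, kv_dim, mlp_dim, mlp_dim]
in_proj = nn.Linear(emb_dim, sum(in_proj_dims), bias=False)

x = torch.randn(1, 10, emb_dim)
# One GEMM produces five logical outputs: q, k, v, the MLP inner activation,
# and the SwiGLU gate.
q, k, v, inner_mlp, gate = in_proj(x).split(in_proj_dims, dim=-1)
print(q.shape, k.shape, inner_mlp.shape)
```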
from torchmultimodal.modules.layers.normalizations import RMSNorm

# from position_embedding import RotaryEmbedding
remove this?
Summary:
This PR adds the main and final piece for upstreaming Parallel Attention Blocks, namely the Parallel Attention Blocks class itself.
It supports MHA, MQA, and GQA attention head setups.
RMSNorm has already landed and so can be used.
Cross Attention has been removed as requested.
Test plan:
Added 3 unit tests, one each for Parallel Attention Blocks using MHA, MQA, and GQA.
Within each test, the number of query heads and the number of KV heads are checked to ensure the appropriate head counts (e.g. for GQA, num_KV heads is set to 2 and then verified).
From there, a Parallel Attention layer runs a forward pass on a fixed single input tensor, and the first row of the output is checked, as is the attn_output shape.
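For concreteness, an outline of the GQA variant of that test flow (fixture and attribute names are assumed for illustration, not copied from the PR):

```python
def test_gqa_parallel_attention(self, gqa_parallel_attention, max_seq_len, embedding_dim):
    # Head-count check: GQA is configured with 2 KV heads in this sketch.
    assert_expected(2, gqa_parallel_attention.num_kv_heads)

    # Forward pass on a fixed single input, then check output shape
    # (and, in the real tests, the first row of the output values).
    x = torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs = 1
    attn_output = gqa_parallel_attention(x)
    assert_expected(torch.Size([1, max_seq_len, embedding_dim]), attn_output.shape)
```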