[WIP] Adding support for FP8 training #218
base: main
Conversation
open_lm/model.py
Outdated
@@ -117,41 +128,72 @@ def __init__(self, layer_id, args: Params):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        if using_te:
instead of if/elses, can we have a single helper function that recursively searches the model and replaces the Linears?
another option is having a linear/layernorm module `m` that is set to either `te` or `nn`
Ideally, yes. But that would also replace the Linear layer found here, which breaks training. So until a fix is found, we need to isolate that particular layer.
can't we just not recurse for special cases?
missed this earlier, but yeah this might also address my swiglu comment below (since it would recurse and replace the linear within the swiglu)
Completed in the latest commit.
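For context, a minimal sketch of what such a recursive swap could look like (the `replace_linear` name and the weight-copy details are illustrative, not the exact helper landed in this PR):

```python
import torch
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    using_te = True
except ImportError:
    using_te = False


def replace_linear(module: nn.Module) -> None:
    """Recursively swap nn.Linear children for te.Linear, copying weights over."""
    if not using_te:
        return  # nothing to do without Transformer Engine
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_linear = te.Linear(child.in_features, child.out_features,
                                   bias=child.bias is not None)
            with torch.no_grad():
                new_linear.weight.copy_(child.weight)
                if child.bias is not None:
                    new_linear.bias.copy_(child.bias)
            setattr(module, name, new_linear)
        else:
            replace_linear(child)  # recurse into submodules
```

A single `replace_linear(model)` call right after construction would then cover every Linear in the tree, without per-module if/else branches.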
open_lm/model.py
Outdated
        torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
        # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
        std = std / math.sqrt(2 * (self.layer_id + 1))
        torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
why do we need to cast to float? does a float cast happen in place?
We don't need the float cast. Removing that in the next commit.
Removed completely, as we are now recursively changing nn.Linear to te.Linear.
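To illustrate the in-place question above (a generic PyTorch sketch, not code from this PR): `.float()` is a no-op on a tensor that is already float32, but returns a new copy otherwise, so an in-place init through the cast can silently miss the real parameter.

```python
import torch

w32 = torch.zeros(4, 4)                        # float32 weight
w16 = torch.zeros(4, 4, dtype=torch.bfloat16)  # bf16 weight

# .float() on a float32 tensor returns the same tensor object, so the
# in-place init still reaches w32:
torch.nn.init.trunc_normal_(w32.float(), std=0.02, a=-0.06, b=0.06)
print(w32.abs().sum() > 0)  # tensor(True)

# .float() on a bf16 tensor returns a new float32 copy; the init lands
# on the copy and w16 stays all zeros:
torch.nn.init.trunc_normal_(w16.float(), std=0.02, a=-0.06, b=0.06)
print(w16.abs().sum() > 0)  # tensor(False)
```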
open_lm/model.py
Outdated
                eps=args.norm_eps,
            )
        else:
            self.attention_norm = args.norm_type(
can we just add te.LayerNorm as one of the args.norm_type options?
Added this in the latest commit: based on the presence of TE, either te.LayerNorm or nn.LayerNorm will be used.
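A hedged sketch of that selection (the function name is illustrative and open_lm's actual norm registry may differ):

```python
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    using_te = True
except ImportError:
    using_te = False


def get_layer_norm_class():
    # Prefer te.LayerNorm when Transformer Engine is available (P5 / H100 nodes),
    # otherwise fall back to the standard torch implementation.
    return te.LayerNorm if using_te else nn.LayerNorm
```

The returned class can then be assigned to args.norm_type so the rest of the model code stays unchanged.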
open_lm/norms.py
Outdated
@@ -55,7 +67,16 @@ def reset_parameters(self) -> None:
            self.bias.zero_()

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        if using_te:
@achalddave do you think we should have a separate class for TeLayerNorm, or do you prefer combining it with the existing layer norm?
We can have a separate class for TeLayerNorm.
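A possible shape for such a separate class in norms.py (a sketch assuming transformer_engine is importable; the TELayerNorm name and constructor arguments are illustrative):

```python
import torch
from torch import Tensor
import transformer_engine.pytorch as te


class TELayerNorm(torch.nn.Module):
    """LayerNorm that delegates to Transformer Engine's implementation."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.norm = te.LayerNorm(hidden_size, eps=eps)

    def forward(self, input: Tensor) -> Tensor:
        return self.norm(input)
```

Keeping it as its own class leaves the existing torch-based LayerNorm untouched for machines without TE.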
open_lm/params.py
Outdated
help="Using SMP Flash Attention.", | ||
) | ||
parser.add_argument( | ||
"--sharding-strategy", |
is this used? if so can we have a more specific name?
also not seeing where --use-smp-flash-attention is used
This is not used for FP8. These are just placeholder flags, defaulted to None, for SageMaker Model Parallel.
Removed this to avoid confusion
@@ -202,9 +245,14 @@ def __init__(self, layer_id, args: Params):
        elif args.ffn_type == "gelu":
Could we also support fp8 for swiglu above? We can make a copy of the Swiglu class in this file. Here's the source for Swiglu https://github.com/facebookresearch/xformers/blob/7f8c290183344343771f4e1d945a8ce10a9500ff/xformers/ops/swiglu_op.py#L430
@rams16592 seems like the recursive replace-linear pattern should take care of this automatically. A function like this seems like it would be great, and we can exclude certain Linears that need to be higher precision for stability. This function has an include field instead of exclude, but hopefully that's easy to flip:
https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/utils.py#L65
Applied this change in the latest commit. Excluding the last output Linear layer from the conversion to te.Linear, as it was running into errors.
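For reference, one way the exclude-style filter could look (a sketch that flips the include logic of the linked open_clip helper; the module name "output" for the final Linear is an assumption about open_lm's naming, and TE is assumed importable):

```python
import torch
import torch.nn as nn
import transformer_engine.pytorch as te


def replace_linear_fp8(module: nn.Module, exclude=("output",)) -> None:
    """Recursively swap nn.Linear for te.Linear, skipping names in `exclude`."""
    for name, child in module.named_children():
        if name in exclude:
            continue  # keep e.g. the final output projection in higher precision
        if isinstance(child, nn.Linear):
            new_linear = te.Linear(child.in_features, child.out_features,
                                   bias=child.bias is not None)
            with torch.no_grad():
                new_linear.weight.copy_(child.weight)
                if child.bias is not None:
                    new_linear.bias.copy_(child.bias)
            setattr(module, name, new_linear)
        else:
            replace_linear_fp8(child, exclude)
```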
…til natively supports FSDP
- train.py, model.py, and norms.py: check if Transformer Engine can be imported (for P5 instances or H100s) and, if so, use FP8 for Linear and LayerNorm layers.
- main.py: changes for FP8 support.
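A rough sketch of the import check and the FP8 training context this describes (the recipe values and function structure are assumptions, not the exact code in this PR):

```python
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    using_te = True
except ImportError:  # no TE wheel installed / not an H100-class node
    using_te = False


def forward_with_optional_fp8(model, inputs):
    if using_te:
        # HYBRID = E4M3 for forward activations/weights, E5M2 for backward grads.
        fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            return model(inputs)
    return model(inputs)
```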