
[WIP] Adding support for FP8 training #218

Open · wants to merge 27 commits into main

Conversation

shahromil16 (Collaborator):

  • Changes made to train.py, model.py, and norms.py to check whether Transformer Engine can be imported (for P5 instances or H100s) and, if so, use FP8 for the Linear and LayerNorm layers (see the sketch below).
  • Minor modifications to main.py for FP8 support.
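As a rough illustration of the import check and FP8 autocast wiring (not the PR's exact code; the helper name is made up, and the TE recipe shown is just one common choice):

```python
# Illustrative sketch only: detect Transformer Engine at import time and,
# if available, run the forward pass under FP8 autocast.
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    using_te = True
except ImportError:
    te = None
    using_te = False


def forward_with_optional_fp8(model, inputs):
    """Hypothetical helper: wrap the forward pass in fp8_autocast when TE is present."""
    if using_te:
        # Hybrid format uses E4M3 for the forward pass and E5M2 for gradients;
        # delayed scaling is TE's standard amax-history-based scaling recipe.
        recipe = DelayedScaling(fp8_format=Format.HYBRID)
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            return model(inputs)
    return model(inputs)
```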

shahromil16 self-assigned this on Feb 21, 2024
shahromil16 changed the title from "Adding support for FP8 training" to "[WIP] Adding support for FP8 training" on Feb 21, 2024
open_lm/model.py Outdated
@@ -117,41 +128,72 @@ def __init__(self, layer_id, args: Params):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
if using_te:
Collaborator:
Instead of if/else blocks, can we have a single helper function that recursively searches the model and replaces the Linears?

Collaborator:
another option is having a linear/layernorm module m that is set to either te or nn
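For illustration, that option could look roughly like this (the alias names are made up, not part of the PR):

```python
# Illustrative only: choose the Linear/LayerNorm implementation once at import
# time, then use the aliases everywhere instead of branching per call site.
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    Linear, LayerNorm = te.Linear, te.LayerNorm
except ImportError:
    Linear, LayerNorm = nn.Linear, nn.LayerNorm

# e.g. in the attention block:
#   self.in_proj = Linear(args.dim, 3 * args.n_heads * head_dim, bias=False)
```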

Collaborator Author:
Ideally, yes. But that would replace the Linear layer found here, which breaks training. So until a fix is found, we need to isolate that particular layer.

Collaborator:
can't we just not recurse for special cases?

Collaborator:
missed this earlier, but yeah this might also address my swiglu comment below (since it would recurse and replace the linear within the swiglu)

Collaborator Author:
Completed in the latest commit.

open_lm/model.py Outdated
torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
Collaborator:
why do we need to cast to float? does a float cast happen in place?

Collaborator Author:
We don't need the float cast. Removing it in the next commit.
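For reference on the question above (plain PyTorch behavior, not PR code): `.float()` returns the same tensor only when it is already fp32; otherwise it returns a copy, so an in-place init on the cast never reaches the original parameter.

```python
import torch

w32 = torch.zeros(4, dtype=torch.float32)
w16 = torch.zeros(4, dtype=torch.bfloat16)

torch.nn.init.trunc_normal_(w32.float(), std=0.02)  # .float() is a no-op here: w32 gets filled
torch.nn.init.trunc_normal_(w16.float(), std=0.02)  # .float() makes a copy: w16 stays zero

print(bool(w32.abs().sum() > 0))  # True
print(bool(w16.abs().sum() > 0))  # False
```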

Collaborator Author:
Removed completely, since we now recursively replace nn.Linear with te.Linear.

open_lm/model.py Outdated
eps=args.norm_eps,
)
else:
self.attention_norm = args.norm_type(
Collaborator:
can we just add te.LayerNorm as one of the args.norm_type?

Collaborator Author:
Added in the latest commit: based on the presence of TE, either te.LayerNorm or nn.LayerNorm will be used.
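Roughly, the selection could look like this (a sketch assuming a helper that maps the configured norm name to a class; the helper and config names are illustrative, not the PR's exact code):

```python
# Illustrative only: prefer TE's fused LayerNorm when transformer_engine imports.
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
except ImportError:
    te = None


def get_norm_class(name: str):
    if name == "layer_norm":
        return te.LayerNorm if te is not None else nn.LayerNorm
    raise ValueError(f"Unsupported norm type: {name}")

# args.norm_type = get_norm_class("layer_norm")
# self.attention_norm = args.norm_type(args.dim, eps=args.norm_eps)
```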

open_lm/norms.py Outdated
@@ -55,7 +67,16 @@ def reset_parameters(self) -> None:
self.bias.zero_()

def forward(self, input: Tensor) -> Tensor:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
if using_te:
Collaborator:
@achalddave do you think we should have a separate class for TeLayerNorm, or do you prefer combining it with the existing LayerNorm?

Collaborator Author:
We can have a separate class for TeLayerNorm.

help="Using SMP Flash Attention.",
)
parser.add_argument(
"--sharding-strategy",
Collaborator:
is this used? if so can we have a more specific name?

Collaborator:
also not seeing where --use-smp-flash-attention is used

Collaborator Author:
These are not used for FP8; they're just placeholder flags defaulted to None for SageMaker Model Parallel.

Collaborator Author:
Removed this to avoid confusion

@@ -202,9 +245,14 @@ def __init__(self, layer_id, args: Params):
elif args.ffn_type == "gelu":
Collaborator:
Could we also support fp8 for swiglu above? We can make a copy of the Swiglu class in this file. Here's the source for Swiglu https://github.com/facebookresearch/xformers/blob/7f8c290183344343771f4e1d945a8ce10a9500ff/xformers/ops/swiglu_op.py#L430

Collaborator:
@rams16592 seems like the recursive replace-linear pattern should take care of this automatically. A function like this seems like it would be great, and we can exclude certain Linears that need to be higher precision for stability. This function has an include field instead of exclude, but hopefully that's easy to flip:
https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/utils.py#L65

Collaborator Author:
Applied this change in the latest commit. The last output Linear layer is excluded from the conversion to te.Linear, as it was running into errors.
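A sketch of that recursive replacement with an exclusion filter, loosely modeled on the open_clip helper linked above but flipped from include to exclude (the function name, the suffix-based exclusion, and the fact that weights are freshly initialized rather than copied are all illustrative assumptions, not the PR's exact code):

```python
# Illustrative only: walk the module tree and swap nn.Linear for te.Linear,
# skipping modules whose qualified name matches the exclude list (e.g. the
# final output projection, which currently errors out under TE).
import torch.nn as nn
import transformer_engine.pytorch as te


def replace_linear_with_te(model: nn.Module, exclude=("output",), prefix=""):
    for name, child in model.named_children():
        qualified = f"{prefix}.{name}" if prefix else name
        if any(qualified.endswith(ex) for ex in exclude):
            continue
        if isinstance(child, nn.Linear):
            # Note: this creates a freshly initialized te.Linear; copying the
            # existing weights over would need explicit handling.
            setattr(model, name, te.Linear(child.in_features, child.out_features,
                                           bias=child.bias is not None))
        else:
            replace_linear_with_te(child, exclude=exclude, prefix=qualified)
    return model
```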
