[WIP] Adding support for FP8 training #218

Open
Wants to merge 27 commits into base: main from feature/fp8

Changes from all commits (27 commits)
7e4dc10  Adding support for FP8 training  (romilshahtri, Feb 21, 2024)
e8cad2a  Linter changes  (shahromil16, Feb 21, 2024)
0514087  Converting all Linears to TE Linears except output Linear  (shahromil16, Feb 29, 2024)
ff8e8c8  Fix linter errors  (shahromil16, Mar 1, 2024)
298471f  Merge remote-tracking branch 'origin/main' into feature/fp8  (shahromil16, Apr 12, 2024)
9224b0e  Rebase from main and update FP8 changes  (shahromil16, Apr 12, 2024)
4563be2  Linter changes  (shahromil16, Apr 12, 2024)
937927a  Adding asserts for FP8  (shahromil16, Apr 13, 2024)
1594b9f  Asserts for FP8  (shahromil16, Apr 13, 2024)
740f2b1  Predefine all_gpus for TE  (shahromil16, Apr 13, 2024)
ccc7eef  Merge remote-tracking branch 'origin/main' into feature/fp8  (shahromil16, Apr 17, 2024)
3713f61  Remove if/else for fp8 checks  (shahromil16, Apr 17, 2024)
14e8278  Remove extra asserts  (shahromil16, Apr 17, 2024)
e572510  Removing unused deps  (shahromil16, Apr 17, 2024)
8350cb9  Update routine for converting NN layers to TE equivalents  (shahromil16, Apr 17, 2024)
a907b3c  Merge remote-tracking branch 'origin/main' into feature/fp8  (shahromil16, Apr 24, 2024)
4e582a0  Update FP8 flags and checks for layers  (shahromil16, Apr 24, 2024)
cdb0cf7  Linter checks  (shahromil16, Apr 24, 2024)
afb46cb  Add checks for autocast function  (shahromil16, Apr 24, 2024)
00c9e5b  Minor edit to model  (shahromil16, Apr 24, 2024)
40c7a6d  Adding default args as Params to SwiGLUTorch  (shahromil16, Apr 24, 2024)
8dbd1d8  Linter fixes  (shahromil16, Apr 24, 2024)
ec91746  Adding Torch Attention TE  (shahromil16, Apr 30, 2024)
afb7a66  Merge remote-tracking branch 'origin/main' into feature/fp8  (shahromil16, May 6, 2024)
29000f3  Fixing FP8+FSDP memory issues by removing FP8 from all activations un…  (shahromil16, May 6, 2024)
936cd9a  Merge remote-tracking branch 'origin/main' into feature/fp8  (shahromil16, Jun 12, 2024)
aca1b75  Updating deps and config  (shahromil16, Jun 19, 2024)
67 changes: 61 additions & 6 deletions open_lm/attention.py
@@ -4,6 +4,15 @@
from torch.nn import functional as F
import xformers.ops as xops

# Flag set when Transformer Engine (TE) is importable; enables the FP8 attention path below.
using_te = False
try:
    import transformer_engine.pytorch as te

    using_te = True
except ImportError:
    using_te = False


def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
"""Create a rectangular causal mask.
@@ -137,6 +146,55 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
)


def torch_attn_te(queries, keys, values, is_causal, attention_mask=None):
    # queries/keys/values are (batch, seq_len, n_heads, head_dim), matching torch_attn above.
    num_q_heads = queries.shape[2]
    head_dim = values.shape[-1]
    scaleddotproductattn_module = te.DotProductAttention(num_attention_heads=num_q_heads, kv_channels=head_dim)
    if is_causal and keys.shape[1] > queries.shape[1] > 1:
        q_seq_len = queries.shape[1]
        k_seq_len = keys.shape[1]
        # Same as above, we would like to use:
        # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
        mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
        if attention_mask is not None:
            apply_attention_mask_(mask, attention_mask, queries_dtype=queries.dtype)
        return (
            scaleddotproductattn_module(
                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attention_mask=mask
            )
            .transpose(1, 2)
            .contiguous()
        )
    else:
        if attention_mask is None:
            bias = None
            # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
            if queries.shape[1] == 1:
                is_causal = False
        else:
            if not is_causal:
                raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
            # Build causal mask that assumes queries are at the end of the sequence.
            batch, q_seq_len, heads, _ = queries.shape
            k_seq_len = keys.shape[1]
            bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
            if attention_mask is not None:
                apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
            # We apply the causal mask in attention instead of using is_causal=True.
            is_causal = False
        return (
            scaleddotproductattn_module(
                queries.transpose(1, 2),
                keys.transpose(1, 2),
                values.transpose(1, 2),
                attention_mask=bias,
                attn_mask_type="causal" if is_causal else None,
            )
            .transpose(1, 2)
            .contiguous()
        )


ATTN_ACTIVATIONS = {
"relu": F.relu,
"relu_squared": lambda x: torch.pow(F.relu(x), 2),
@@ -189,12 +247,7 @@ def custom_attn(
return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values)


def get_attn_func(
attn_name,
attn_activation=None,
attn_seq_scalar=None,
alpha=None,
):
def get_attn_func(attn_name, attn_activation=None, attn_seq_scalar=None, alpha=None, use_fp8=False):
if attn_name == "auto":
return xformers_attn if torch.cuda.is_available() else torch_attn
elif attn_name == "xformers_attn":
@@ -206,6 +259,8 @@ def get_attn_func(
# call .contiguous() on the output tensor. [#188]
return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
elif attn_name == "torch_attn":
# if using_te and use_fp8:
# return torch_attn_te
return torch_attn
elif attn_name == "custom_attn":
assert (
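For context on how this is used: get_attn_func now accepts a use_fp8 flag, but the TE dispatch above is still commented out, so "torch_attn" currently resolves to torch_attn either way. Below is a minimal usage sketch, not part of the PR, with illustrative tensor sizes and the (batch, seq_len, n_heads, head_dim) layout that torch_attn expects.

import torch

from open_lm.attention import get_attn_func

# With the commented-out branch enabled, this would return torch_attn_te when
# transformer_engine imports successfully and use_fp8=True; for now it returns torch_attn.
attn_fn = get_attn_func("torch_attn", use_fp8=True)

# Illustrative shapes: batch=2, seq_len=128, n_heads=8, head_dim=64 (requires a CUDA device).
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = attn_fn(q, k, v, is_causal=True)  # output keeps the (batch, seq, heads, head_dim) layout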
7 changes: 5 additions & 2 deletions open_lm/distributed.py
@@ -57,6 +57,7 @@ def init_distributed_device(args):
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
args.world_group = None
# For testing, allow forcing distributed mode to test distributed code path even on one gpu.
if is_using_distributed() or args.force_distributed:
if "SLURM_PROCID" in os.environ:
@@ -74,7 +75,7 @@
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
torch.distributed.init_process_group(
args.world_group = torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
@@ -85,7 +86,9 @@
# Note that this currently assumes that the world size is all gpus in a node.
assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
args.world_group = torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
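A note on the API being captured here: in current PyTorch releases, torch.distributed.init_process_group returns None, so args.world_group may still end up None, in which case FSDP simply falls back to the default (world) group. If a concrete ProcessGroup handle is wanted, it is usually fetched explicitly after initialization; a small sketch of that pattern (not from this PR):

import torch.distributed as dist

def init_world_group(backend="nccl", init_method="env://"):
    # init_process_group returns None; the default group is exposed separately.
    dist.init_process_group(backend=backend, init_method=init_method)
    return dist.group.WORLD  # a ProcessGroup handle usable as FSDP's process_group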
11 changes: 9 additions & 2 deletions open_lm/main.py
@@ -65,7 +65,6 @@
terminate_sync_process,
)


LATEST_CHECKPOINT_NAME = "epoch_latest.pt"


@@ -466,13 +465,18 @@ def main(args):

random_seed(args.seed, 0)

tensor_parallel_group = None
if args.use_fp8:
tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend="nccl")
logging.info("Using FP8 to run training.")

model = None
if args.hf_model is not None:
model = create_wrapped_hf_model(args)
else:
# Optional: Use meta device
with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device):
model = create_model(args)
model = create_model(args, tensor_parallel_group)

args.vocab_size = model.vocab_size
args.seq_len = model.seq_len
@@ -548,8 +552,10 @@ def main(args):

# Initialize FSDP. Use the same seed across workers to ensure reset_parameters is the same across workers.
random_seed(args.seed, rank=0)

model = FSDP(
model,
process_group=args.world_group,
auto_wrap_policy=transformer_auto_wrapper_policy,
device_id=device,
mixed_precision=mp_policy,
@@ -832,6 +838,7 @@ def main(args):
total_steps=total_steps,
args=args,
tb_writer=writer,
data_parallel_group=args.world_group,
)

if args.distributed:
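The changes above set up the process groups and pass data_parallel_group into the train loop; the FP8 execution itself comes from wrapping forward passes in Transformer Engine's fp8_autocast so that the TE Linear layers run their GEMMs in FP8. Below is a minimal sketch of such a training step, assuming TE's delayed-scaling recipe; the recipe values, the model's return convention, and the loss computation are illustrative rather than taken from this PR.

import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative recipe; the PR does not pin these values.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

def fp8_train_step(model, tokens, optimizer, data_parallel_group):
    optimizer.zero_grad(set_to_none=True)
    # Amax/scale reductions for delayed scaling are performed over fp8_group.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group):
        logits, _, _ = model(tokens)  # open_lm models return a (logits, ..., ...) tuple
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), tokens[:, 1:].reshape(-1))
    loss.backward()  # the backward pass runs outside the autocast context
    optimizer.step()
    return loss.item()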
77 changes: 60 additions & 17 deletions open_lm/model.py
@@ -35,6 +35,24 @@
except ImportError:
MambaLMHeadModel = None

# Flag set when Transformer Engine (TE) is importable; enables the FP8 code paths below.
using_te = False
LinearTE = nn.Linear
try:
    import transformer_engine.pytorch as te

    using_te = True

    class LinearTE(te.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, inp: torch.Tensor, is_first_microbatch: bool = True):
            # Always treat the call as the first microbatch so TE re-casts the weights to FP8
            # on every forward instead of reusing a cached FP8 copy.
            return super().forward(inp, is_first_microbatch=True)

except ImportError:
    using_te = False

# from openclip
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
@@ -86,7 +104,9 @@ class Params:
seq_len: int = 2048
post_embed_norm: bool = False
weight_tying: bool = False
norm_type: nn.Module = nn.LayerNorm
norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm
linear_type: nn.Module = LinearTE if using_te else nn.Linear
te_device: str = "cuda" if using_te else None
attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
apply_qk_norm: bool = False
moe_loss_weight: float = 0.1
@@ -119,8 +139,8 @@ def __init__(self, layer_id, args: Params):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.in_proj = args.linear_type(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.te_device)
self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.te_device)
self.pos_embed = get_pos_embed(args)
self.attn_fn = args.attn_func
self.apply_qk_norm = args.apply_qk_norm
@@ -130,6 +150,7 @@ def __init__(self, layer_id, args: Params):
args.norm_type(
args.n_heads * self.head_dim,
eps=args.norm_eps,
device=args.te_device,
)
if self.apply_qk_norm
else nn.Identity()
@@ -138,6 +159,7 @@ def __init__(self, layer_id, args: Params):
args.norm_type(
args.n_heads * self.head_dim,
eps=args.norm_eps,
device=args.te_device,
)
if self.apply_qk_norm
else nn.Identity()
@@ -195,13 +217,13 @@ class GemmaMLP(nn.Module):
Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201
"""

def __init__(self, dim: int, hidden_dim: int, layer_id: int):
def __init__(self, dim: int, hidden_dim: int, layer_id: int, args: Params):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.gate_proj = nn.Linear(dim, hidden_dim)
self.up_proj = nn.Linear(dim, hidden_dim)
self.down_proj = nn.Linear(hidden_dim, dim)
self.gate_proj = nn.Linear(dim, hidden_dim, device=args.te_device)
self.up_proj = nn.Linear(dim, hidden_dim, device=args.te_device)
self.down_proj = nn.Linear(hidden_dim, dim, device=args.te_device)
self._layer_id = layer_id

def forward(self, x):
@@ -225,10 +247,10 @@ def reset_parameters(self):
# Same as pseudocode provided from xformers SwiGLU
# https://github.com/facebookresearch/xformers
class SwiGLUTorch(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, bias=True):
def __init__(self, in_dim, hidden_dim, out_dim, args: Params = Params, bias=True):
super().__init__()
self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias)
self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)
self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias, device=args.te_device)
self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias, device=args.te_device)

def forward(self, x):
gate, x = self.w12(x).chunk(2, dim=-1)
@@ -252,17 +274,17 @@ def __init__(self, layer_id, args: Params):
elif args.ffn_type == "swiglu_torch":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, args, bias=False)
elif args.ffn_type == "gelu":
Collaborator:
Could we also support fp8 for swiglu above? We can make a copy of the Swiglu class in this file. Here's the source for Swiglu https://github.com/facebookresearch/xformers/blob/7f8c290183344343771f4e1d945a8ce10a9500ff/xformers/ops/swiglu_op.py#L430

Collaborator:
@rams16592 seems like the recursive replace-linear pattern should take care of this automatically. A function like this seems like it would be great, and we can exclude certain linears that need to be higher precision for stability. This function has an include field instead of exclude, but hopefully that's easy to flip:
https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/utils.py#L65

Collaborator (author):
Applied this change in the latest commit. Excluding the last output Linear layer from the conversion to TE Linear, as it's running into errors. (A sketch of this replacement pattern appears after the model.py diff below.)

# Follows mosaic mpt7b, but without a bias.
self.hidden_dim = args.dim * 4
self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False, device=args.te_device)
self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False, device=args.te_device)
self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
elif args.ffn_type == "gemma_geglu":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id)
self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id, args)
elif args.ffn_type == "moe":
moe_args = MoEArgs(
hidden_size=args.dim,
@@ -283,10 +305,12 @@ def __init__(self, layer_id, args: Params):
self.attention_norm = args.norm_type(
args.dim,
eps=args.norm_eps,
device=args.te_device,
)
self.ffn_norm = args.norm_type(
args.dim,
eps=args.norm_eps,
device=args.te_device,
)
self.attention.seq_len = args.seq_len
self.reset_parameters()
@@ -455,9 +479,15 @@ def create_params(args):
vocab_size=cfg["vocab_size"],
post_embed_norm=cfg["post_embed_norm"],
weight_tying=cfg["weight_tying"],
norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8),
linear_type=LinearTE if (using_te and args.use_fp8) else nn.Linear,
te_device="cuda" if (using_te and args.use_fp8) else None,
attn_func=get_attn_func(
args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha
args.attn_name,
args.attn_activation,
args.attn_seq_scalar,
args.attn_seq_scalar_alpha,
use_fp8=args.use_fp8,
),
apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
@@ -495,10 +525,23 @@ def forward(self, x):
return out, None, None


def create_model(args):
def te_linear_ops(model, exclude_modules=["output"], tensor_parallel_group=None):
    # Attach the tensor-parallel group to every te.Linear in the model.
    # NOTE: exclude_modules is not used in this routine.
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            te_linear_ops(module, exclude_modules, tensor_parallel_group)
        if isinstance(module, te.Linear):
            model._modules[name].set_tensor_parallel_group(tensor_parallel_group)
    return model


def create_model(args, tensor_parallel_group=None):
    if "mamba" in args.model:
        model = Mamba(create_params(args))
        if tensor_parallel_group is not None and using_te:
            model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group=tensor_parallel_group)
        return model
    else:
        model = Transformer(create_params(args))
        if tensor_parallel_group is not None and using_te:
            model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group=tensor_parallel_group)
        return model
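For reference, here is a minimal sketch of the recursive nn.Linear to te.Linear replacement discussed in the review thread above, with an exclude list so that modules such as the output head stay in higher precision. The helper name and the weight-copy details are illustrative, adapted from the open_clip utility linked by the reviewer; the PR itself builds models with te.Linear directly via Params.linear_type and uses te_linear_ops only to attach the tensor-parallel group.

import torch
import torch.nn as nn
import transformer_engine.pytorch as te

def replace_linear_with_te(model: nn.Module, exclude_names=("output",)):
    """Recursively swap nn.Linear modules for te.Linear, skipping excluded module names."""
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in exclude_names:
            te_linear = te.Linear(module.in_features, module.out_features, bias=module.bias is not None)
            with torch.no_grad():
                te_linear.weight.copy_(module.weight)
                if module.bias is not None:
                    te_linear.bias.copy_(module.bias)
            setattr(model, name, te_linear)
        else:
            # Recurse into container modules (blocks, attention, feed-forward, ...).
            replace_linear_with_te(module, exclude_names)
    return model

Excluding by module name keeps the final projection in higher precision, which matches the author's note above about the output Linear hitting errors under TE.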