Skip to content

Commit

Permalink
Merge pull request #4 from YerevaNN/model_loading
Browse files Browse the repository at this point in the history
Add the OPT model implementation, OPT model loading from Hugging Face, and support for training OPT models with FSDP
  • Loading branch information
philippguevorguian authored Aug 24, 2024
2 parents 591b7dd + b08397a commit 21d8e10
Show file tree
Hide file tree
Showing 11 changed files with 596 additions and 40 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__pycache__
.idea
.vscode
.DS_Store
*.egg-info
build
Expand Down
64 changes: 32 additions & 32 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,38 @@ def build_test_list():
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test 1f1b",
"pp_1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test gpipe",
"pp_gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
# OverrideDefinitions(
# [
# [
# "--checkpoint.enable_checkpoint",
# "--experimental.pipeline_parallel_degree 2",
# "--experimental.pipeline_parallel_split_points layers.4",
# "--experimental.pipeline_parallel_schedule 1f1b",
# "--training.data_parallel_degree 1",
# "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
# ],
# ],
# "PP 1D test 1f1b",
# "pp_1f1b",
# requires_seed_checkpoint=True,
# ngpu=2,
# ),
# OverrideDefinitions(
# [
# [
# "--checkpoint.enable_checkpoint",
# "--experimental.pipeline_parallel_degree 2",
# "--experimental.pipeline_parallel_split_points layers.4",
# "--experimental.pipeline_parallel_schedule gpipe",
# "--training.data_parallel_degree 1",
# "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
# ],
# ],
# "PP 1D test gpipe",
# "pp_gpipe",
# requires_seed_checkpoint=True,
# ngpu=2,
# ),
OverrideDefinitions(
[
[
Expand Down
13 changes: 12 additions & 1 deletion torchtitan/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
from torchtitan.models.opt import opt_configs, OPT, load_opt_weights

# Registry: model family name -> dict of size-keyed model configurations
# (e.g. "debugmodel", "125M").
models_config = {
    "llama2": llama2_configs,
    "llama3": llama3_configs,
    "opt": opt_configs,
}

# Registry: model family name -> model class implementing that family.
# Both llama variants share the same Transformer implementation.
model_name_to_cls = {
    "llama2": Transformer,
    "llama3": Transformer,
    "opt": OPT,
}

# Registry: model family name -> tokenizer backend identifier.
# NOTE(review): "tiktoken" for OPT looks copied from llama3 — OPT upstream
# uses a GPT-2-style BPE tokenizer; confirm against the tokenizer builder.
model_name_to_tokenizer = {
    "llama2": "sentencepiece",
    "llama3": "tiktoken",
    "opt": "tiktoken",
}

# Registry: model family name -> function that loads pretrained weights.
# Only OPT currently supports loading checkpoints this way.
model_name_to_weights_loading_fns = {
    "opt": load_opt_weights,
}
2 changes: 2 additions & 0 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
return nn.LayerNorm(dim, eps=eps, bias=False)
elif norm_type == "np_layernorm":
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "layernorm_bias":
return nn.LayerNorm(dim, eps=eps, bias=True)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
Expand Down
20 changes: 20 additions & 0 deletions torchtitan/models/opt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# OPT model: Copyright (c) Meta Platforms, Inc. All Rights Reserved.
# TODO(review): the upstream template placeholders ("<model name> is licensed
# under the <license name>") were left unfilled — record the actual OPT model
# license here before release.

from torchtitan.models.opt.model import ModelArgs, OPT
from torchtitan.models.opt.utils import load_opt_weights

# Everything importable by name from this package; `torchtitan/models/__init__`
# pulls `opt_configs` from here, so export it (and ModelArgs) explicitly.
__all__ = ["ModelArgs", "OPT", "load_opt_weights", "opt_configs"]

# Size-keyed OPT configurations; keys are selected via the training config.
opt_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=8),
    "125M": ModelArgs(dim=768, n_layers=12, n_heads=12),
    # Published OPT sizes (per the Hugging Face OPT configs) — left disabled
    # until validated against this implementation:
    # "1.3B": ModelArgs(dim=2048, n_layers=24, n_heads=32),
    # "6.7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
}
Loading

0 comments on commit 21d8e10

Please sign in to comment.