Merge pull request #29 from YerevaNN/model_loading
Model loading
tigranfah authored Oct 2, 2024
2 parents 225c78c + 57a906d commit 4458b78
Showing 8 changed files with 259 additions and 8 deletions.
4 changes: 2 additions & 2 deletions submitit_train.py
@@ -19,9 +19,9 @@
jobs = []
with executor.batch():
    for _ in range(1):
        train_config = './train_configs/chemlactica_125m.toml'
        # train_config = './train_configs/chemlactica_125m.toml'
        # train_config = './train_configs/chemlactica_1.3b.toml'
        # train_config = './train_configs/llama3_8b.toml'
        train_config = './train_configs/llama3.2_1b.toml'
        # train_config = './train_configs/debug_model.toml'
        function = submitit.helpers.CommandFunction([
            'python3', '-m', 'torch.distributed.run',
5 changes: 3 additions & 2 deletions torchtitan/models/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer, download_llama3_weights
from torchtitan.models.opt import opt_configs, OPT, download_opt_weights, export_opt_weights

models_config = {
@@ -26,7 +26,8 @@
}

model_name_to_weights_download_fns = {
"opt": download_opt_weights
"opt": download_opt_weights,
"llama3": download_llama3_weights
}

model_name_to_weights_export_fns = {
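The new "llama3" entry registers download_llama3_weights (defined in torchtitan/models/llama/utils.py below) alongside the existing OPT downloader. A minimal sketch of how the registry could be dispatched; the actual call site is not part of this diff, so the argument values here are assumptions:

# Hedged sketch only; the real dispatch happens outside this diff.
from torchtitan.models import model_name_to_weights_download_fns

download_fn = model_name_to_weights_download_fns["llama3"]  # resolves to download_llama3_weights
# download_fn(model, weights_path="meta-llama/Llama-3.2-1B",
#             source="huggingface", token_embedding_size=128256)  # illustrative values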
13 changes: 11 additions & 2 deletions torchtitan/models/llama/__init__.py
@@ -8,8 +8,9 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.models.llama.utils import download_llama3_weights

__all__ = ["Transformer"]
__all__ = ["Transformer", download_llama3_weights]

llama2_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16),
@@ -30,6 +31,14 @@

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"1B": ModelArgs(
dim=2048,
n_layers=16,
n_heads=32,
n_kv_heads=8,
rope_theta=500000,
share_embeddings=True
),
"8B": ModelArgs(
dim=4096,
n_layers=32,
@@ -57,4 +66,4 @@
        multiple_of=4096,
        rope_theta=500000,
    ),
}
}
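The new "1B" flavor matches Llama 3.2 1B: 16 layers, hidden size 2048, 32 query heads with 8 KV heads, and tied input/output embeddings. A minimal sketch of instantiating it directly; in training the vocabulary size comes from the tokenizer, so the value below is only an assumption:

# Minimal sketch; vocab_size is normally filled in from the tokenizer at training time.
from torchtitan.models.llama import llama3_configs, Transformer

model_args = llama3_configs["1B"]   # dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, share_embeddings=True
model_args.vocab_size = 128256      # assumed Llama 3 tokenizer size
model = Transformer(model_args)     # builds the model without a separate output head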
10 changes: 8 additions & 2 deletions torchtitan/models/llama/model.py
@@ -35,6 +35,7 @@ class ModelArgs:
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"
    share_embeddings: bool = False


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -241,7 +242,7 @@ def __init__(
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
@@ -377,7 +378,10 @@ def __init__(self, model_args: ModelArgs):
            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
        )

        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.output = None
        if not self.model_args.share_embeddings:
            self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

        self.init_weights()

    def init_weights(self):
@@ -439,6 +443,8 @@ def forward(self, tokens: torch.Tensor):
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        if self.model_args.share_embeddings:
            return torch.matmul(h, self.tok_embeddings.weight.t()).float()
        output = self.output(h).float() if self.output else h
        return output

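With share_embeddings enabled, the model skips the separate nn.Linear output head and projects the final hidden states through the transpose of the token embedding matrix. A minimal standalone sketch of that tied projection (shapes are illustrative, not taken from the diff):

import torch
import torch.nn as nn

# Illustrative shapes only.
vocab_size, dim = 128256, 2048
tok_embeddings = nn.Embedding(vocab_size, dim)

h = torch.randn(2, 16, dim)                           # hidden states after the final norm
logits = torch.matmul(h, tok_embeddings.weight.t())   # same math as the share_embeddings branch
assert logits.shape == (2, 16, vocab_size)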
93 changes: 93 additions & 0 deletions torchtitan/models/llama/utils.py
@@ -0,0 +1,93 @@
from transformers import AutoModelForCausalLM
import torch
from torchtitan.models.llama import Transformer
from torchtitan.logging import logger


# reverse_permute for sliced rotary
def reverse_permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)


# permute for sliced rotary
def permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)


def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
    """
    Get a mapping between state dict keys of different implementations.

    Args:
        num_layers (int): number of transformer layers (blocks).

    Returns:
        dict: mapping between local implementation state dict keys and hf implementation state dict keys
    """
    keys_mapping = {
        'tok_embeddings.weight': 'model.embed_tokens.weight',
        # add layer weight mappings here
        'norm.weight': 'model.norm.weight',
        # "output.weight": 'lm_head.weight',
    }
    for layer in range(num_layers):
        keys_mapping.update({
            f'layers.{layer}.attention.wq.weight': f'model.layers.{layer}.self_attn.q_proj.weight',
            f'layers.{layer}.attention.wk.weight': f'model.layers.{layer}.self_attn.k_proj.weight',
            f'layers.{layer}.attention.wv.weight': f'model.layers.{layer}.self_attn.v_proj.weight',
            f'layers.{layer}.attention.wo.weight': f'model.layers.{layer}.self_attn.o_proj.weight',
            f'layers.{layer}.feed_forward.w1.weight': f'model.layers.{layer}.mlp.gate_proj.weight',
            f'layers.{layer}.feed_forward.w3.weight': f'model.layers.{layer}.mlp.up_proj.weight',
            f'layers.{layer}.feed_forward.w2.weight': f'model.layers.{layer}.mlp.down_proj.weight',
            f'layers.{layer}.attention_norm.weight': f'model.layers.{layer}.input_layernorm.weight',
            f'layers.{layer}.ffn_norm.weight': f'model.layers.{layer}.post_attention_layernorm.weight'
        })

    return keys_mapping


def download_llama3_weights(model: Transformer, weights_path: str, source: str, token_embedding_size: int):
    """
    Load pretrained Llama 3 weights into the titan Transformer in place.

    Args:
        model (Transformer): titan model to load the weights into.
        weights_path (str): Hugging Face model id or local path of the pretrained checkpoint.
        source (str): source of the weights; currently only "huggingface" is supported.
        token_embedding_size (int): target token embedding (vocabulary) size; currently unused,
            since the resize call below is commented out.
    """
    if source == "huggingface":
        hf_model = AutoModelForCausalLM.from_pretrained(weights_path)
        # hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
        keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers)
        hf_state_dict = hf_model.state_dict()
        corrected_state_dict = {}
        for key, value in keys_mapping.items():
            assert hf_state_dict[value].shape == model.state_dict()[key].shape
            if "self_attn.q_proj.weight" in value:
                corrected_state_dict[key] = reverse_permute(
                    hf_state_dict[value], model.model_args.n_heads,
                    model.model_args.dim, model.model_args.dim
                )
            elif "self_attn.k_proj.weight" in value:
                kv_dim = model.model_args.dim // (model.model_args.n_heads // model.model_args.n_kv_heads)
                corrected_state_dict[key] = reverse_permute(
                    hf_state_dict[value], model.model_args.n_kv_heads,
                    kv_dim, model.model_args.dim
                )
            else:
                corrected_state_dict[key] = hf_state_dict[value]

        with torch.device(model.freqs_cis.device):
            corrected_state_dict["freqs_cis"] = model._precompute_freqs_cis()

        model.load_state_dict(corrected_state_dict)
        logger.info("Successfully loaded Llama 3 weights into the titan model.")

        # from transformers import AutoTokenizer
        # tok = AutoTokenizer.from_pretrained(weights_path)
        # device = "cuda"
        # hf_model.to(device)
        # model.eval()
        # text = "Hello world"
        # data = tok(text, return_tensors="pt").to(device)
        # hf_logits = hf_model(**data).logits
        # logits = model(data.input_ids)
        # print(torch.allclose(hf_logits, logits, atol=1e-4))
    else:
        raise NotImplementedError(f"Weight source '{source}' is not supported.")
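The permute/reverse_permute pair are exact inverses: reverse_permute undoes the rotary head re-ordering that the Hugging Face conversion applies to q_proj and k_proj, which is why it is used when importing the Hugging Face checkpoint above. A small sanity check, assuming the helpers are importable from torchtitan.models.llama.utils:

import torch
from torchtitan.models.llama.utils import permute, reverse_permute

n_heads, head_dim = 4, 8
dim = n_heads * head_dim                  # toy sizes
w = torch.randn(dim, dim)                 # stand-in for a q_proj weight
roundtrip = permute(reverse_permute(w, n_heads, dim, dim), n_heads, dim, dim)
assert torch.equal(roundtrip, w)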
2 changes: 2 additions & 0 deletions torchtitan/tokenizers/tokenizer/custom.py
@@ -81,4 +81,6 @@ def n_words(self) -> int:

    @property
    def padded_n_words(self):
        if self.n_words % self.pad_to_multiple_of == 0:
            return self.n_words
        return self._n_words + self.pad_to_multiple_of - self._n_words % self.pad_to_multiple_of
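The added early return keeps an already-aligned vocabulary size unchanged; otherwise the size is rounded up to the next multiple of pad_to_multiple_of. The same arithmetic as a standalone sketch (helper name and numbers are illustrative only):

def padded(n_words: int, pad_to_multiple_of: int) -> int:
    # Round n_words up to the nearest multiple, mirroring padded_n_words above.
    if n_words % pad_to_multiple_of == 0:
        return n_words
    return n_words + pad_to_multiple_of - n_words % pad_to_multiple_of

assert padded(50000, 128) == 50048   # remainder 80, so pad by 48
assert padded(50048, 128) == 50048   # already a multiple, unchanged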
68 changes: 68 additions & 0 deletions train_configs/llama3.2_1b.toml
@@ -0,0 +1,68 @@
# torchtitan Config.toml

[job]
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Llama 3.2 training"
use_for_integration_test = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_aim = false
save_aim_folder = "aim"

[model]
name = "llama3"
flavor = "1B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "meta-llama/Llama-3.2-1B"

[optimizer]
name = "AdamW"
lr = 1.0e-4

[training]
batch_size = 10
gradient_accumulation_steps = 1
seq_len = 2048
warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 30
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style = "chemlactica_style"

[dataloader]
num_workers = 4

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
load_folder = "meta-llama/Llama-3.2-1B"
save_folder = "yerevann/Llama-3.2-1B"
interval_type = "steps"
interval = 1000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = false
72 changes: 72 additions & 0 deletions train_configs/llama3.2_1b_conversion.toml
@@ -0,0 +1,72 @@
# torchtitan Config.toml

[job]
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Llama 3.2 training"
use_for_integration_test = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_aim = false
save_aim_folder = "aim"

[model]
name = "llama3"
flavor = "1B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./test/assets/test_tiktoken.model"
tokenizer_path = "meta-llama/Llama-3.2-1B"

[optimizer]
name = "AdamW"
lr = 1.0e-4

[training]
batch_size = 1
gradient_accumulation_steps = 3
seq_len = 2048
warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style = "chemlactica_style"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
load_folder = "meta-llama/Llama-3.2-1B"
save_folder = "meta-llama/Llama-3.2-1B"
interval_type = "steps"
interval = 1000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]

[model_download_export]
to_titan = true
weights_source = "huggingface"
# to_hf = true

[activation_checkpoint]
mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = false
