forked from pytorch/torchtitan
Merge pull request #29 from YerevaNN/model_loading: Model loading

Showing 8 changed files with 259 additions and 8 deletions.
@@ -0,0 +1,93 @@
from transformers import AutoModelForCausalLM
import torch
from torchtitan.models.llama import Transformer
from torchtitan.logging import logger


# reverse_permute for sliced rotary
def reverse_permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)


# permute for sliced rotary
def permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)


def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
    """
    Get a mapping between state dict keys of different implementations.

    Args:
        num_layers (int): number of transformer layers (blocks).

    Returns:
        dict: mapping between local implementation state dict keys and
            hf implementation state dict keys.
    """
    keys_mapping = {
        'tok_embeddings.weight': 'model.embed_tokens.weight',
        # add layer weight mappings here
        'norm.weight': 'model.norm.weight',
        # "output.weight": 'lm_head.weight',
    }
    for layer in range(num_layers):
        keys_mapping.update({
            f'layers.{layer}.attention.wq.weight': f'model.layers.{layer}.self_attn.q_proj.weight',
            f'layers.{layer}.attention.wk.weight': f'model.layers.{layer}.self_attn.k_proj.weight',
            f'layers.{layer}.attention.wv.weight': f'model.layers.{layer}.self_attn.v_proj.weight',
            f'layers.{layer}.attention.wo.weight': f'model.layers.{layer}.self_attn.o_proj.weight',
            f'layers.{layer}.feed_forward.w1.weight': f'model.layers.{layer}.mlp.gate_proj.weight',
            f'layers.{layer}.feed_forward.w3.weight': f'model.layers.{layer}.mlp.up_proj.weight',
            f'layers.{layer}.feed_forward.w2.weight': f'model.layers.{layer}.mlp.down_proj.weight',
            f'layers.{layer}.attention_norm.weight': f'model.layers.{layer}.input_layernorm.weight',
            f'layers.{layer}.ffn_norm.weight': f'model.layers.{layer}.post_attention_layernorm.weight'
        })

    return keys_mapping


def download_llama3_weights(model: Transformer, weights_path: str, source: str, token_embedding_size: int):
    """
    Load pretrained Llama 3 weights into a torchtitan Transformer.

    Args:
        model (Transformer): torchtitan model to load the weights into.
        weights_path (str): model name or path to load the weights from.
        source (str): weight source; only "huggingface" is supported for now.
        token_embedding_size (int): target token embedding size (currently
            unused; see the commented-out resize call below).
    """
    if source == "huggingface":
        hf_model = AutoModelForCausalLM.from_pretrained(weights_path)
        # hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
        keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers)
        hf_state_dict = hf_model.state_dict()
        corrected_state_dict = {}
        for key, value in keys_mapping.items():
            assert hf_state_dict[value].shape == model.state_dict()[key].shape
            if "self_attn.q_proj.weight" in value:
                corrected_state_dict[key] = reverse_permute(
                    hf_state_dict[value], model.model_args.n_heads,
                    model.model_args.dim, model.model_args.dim
                )
            elif "self_attn.k_proj.weight" in value:
                kv_dim = model.model_args.dim // (model.model_args.n_heads // model.model_args.n_kv_heads)
                corrected_state_dict[key] = reverse_permute(
                    hf_state_dict[value], model.model_args.n_kv_heads,
                    kv_dim, model.model_args.dim
                )
            else:
                corrected_state_dict[key] = hf_state_dict[value]

        with torch.device(model.freqs_cis.device):
            corrected_state_dict["freqs_cis"] = model._precompute_freqs_cis()

        model.load_state_dict(corrected_state_dict)
        logger.info("Successfully loaded Llama 3 weights into the titan model.")

        # Sanity check: compare logits of the two implementations on a short prompt.
        # from transformers import AutoTokenizer
        # tok = AutoTokenizer.from_pretrained(weights_path)
        # device = "cuda"
        # hf_model.to(device)
        # model.eval()
        # text = "Hello world"
        # data = tok(text, return_tensors="pt").to(device)
        # hf_logits = hf_model(**data).logits
        # logits = model(data.input_ids)
        # print(torch.allclose(hf_logits, logits, atol=1e-4))
    else:
        raise NotImplementedError(f"weight source '{source}' is not supported")
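
The permute/reverse_permute pair above exists because the Hugging Face conversion of Llama checkpoints stores the q/k projection weights with the rotary halves interleaved differently from the layout torchtitan expects; reverse_permute undoes that interleaving row-wise. A minimal sketch of the round-trip property, using hypothetical small shapes (4 heads, model dim 32) chosen purely for illustration:

import torch

def reverse_permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

n_heads, dim = 4, 32  # illustrative shapes, not taken from the diff
w = torch.randn(dim, dim)

# The two transforms are inverses, so converting HF -> titan -> HF is lossless.
assert torch.equal(permute(reverse_permute(w, n_heads, dim, dim), n_heads, dim, dim), w)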
@@ -0,0 +1,68 @@
# torchtitan Config.toml

[job]
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Llama 3.2 training"
use_for_integration_test = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_aim = false
save_aim_folder = "aim"

[model]
name = "llama3"
flavor = "1B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "meta-llama/Llama-3.2-1B"

[optimizer]
name = "AdamW"
lr = 1.0e-4

[training]
batch_size = 10
gradient_accumulation_steps = 1
seq_len = 2048
warmup_steps = 500  # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0  # grad norm clipping
steps = 30
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
# dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini"  # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style = "chemlactica_style"

[dataloader]
num_workers = 4

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
load_folder = "meta-llama/Llama-3.2-1B"
save_folder = "yerevann/Llama-3.2-1B"
interval_type = "steps"
interval = 1000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'none'  # ['none', 'selective', 'full']
selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = false
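
As a side note, a config like the one above can be sanity-checked before launching a run by parsing it with the standard-library tomllib (Python 3.11+). A minimal sketch; the file name here is a placeholder, not the actual path from this commit:

import tomllib

# Placeholder path; substitute the config file added in this commit.
with open("llama3_1b_train.toml", "rb") as f:
    cfg = tomllib.load(f)

# Spot-check a few fields the trainer depends on.
assert cfg["model"]["flavor"] == "1B"
assert cfg["training"]["dataset"] == "chemlactica_train"
print(cfg["checkpoint"]["load_folder"])  # meta-llama/Llama-3.2-1B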
@@ -0,0 +1,72 @@
# torchtitan Config.toml

[job]
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Llama 3.2 training"
use_for_integration_test = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_aim = false
save_aim_folder = "aim"

[model]
name = "llama3"
flavor = "1B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purposes only
# tokenizer_path = "./test/assets/test_tiktoken.model"
tokenizer_path = "meta-llama/Llama-3.2-1B"

[optimizer]
name = "AdamW"
lr = 1.0e-4

[training]
batch_size = 1
gradient_accumulation_steps = 3
seq_len = 2048
warmup_steps = 500  # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0  # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
# dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini"  # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style = "chemlactica_style"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
load_folder = "meta-llama/Llama-3.2-1B"
save_folder = "meta-llama/Llama-3.2-1B"
interval_type = "steps"
interval = 1000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem"  # ["disabled", "async", "async_with_pinned_mem"]

[model_download_export]
to_titan = true
weights_source = "huggingface"
# to_hf = true

[activation_checkpoint]
mode = 'none'  # ['none', 'selective', 'full']
selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = false
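
The second config differs from the first mainly in the new [model_download_export] table, which is what would trigger the Hugging Face weight import implemented above. A plausible way the trainer could wire the two together; the attribute names on job_config mirror the TOML keys but are assumptions, not taken from this diff:

# Hypothetical glue code; the job_config attribute names are assumed,
# as is using the model's own vocab size for token_embedding_size.
def maybe_load_pretrained(model, job_config):
    export_cfg = job_config.model_download_export
    if export_cfg.to_titan:
        download_llama3_weights(
            model,
            weights_path=job_config.checkpoint.load_folder,
            source=export_cfg.weights_source,  # "huggingface"
            token_embedding_size=model.model_args.vocab_size,
        )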