diff --git a/submitit_train.py b/submitit_train.py
index 8c4c597b..bb7a0190 100644
--- a/submitit_train.py
+++ b/submitit_train.py
@@ -19,9 +19,9 @@
 jobs = []
 with executor.batch():
     for _ in range(1):
-        train_config = './train_configs/chemlactica_125m.toml'
+        # train_config = './train_configs/chemlactica_125m.toml'
         # train_config = './train_configs/chemlactica_1.3b.toml'
-        # train_config = './train_configs/llama3_8b.toml'
+        train_config = './train_configs/llama3.2_1b.toml'
         # train_config = './train_configs/debug_model.toml'
         function = submitit.helpers.CommandFunction([
             'python3', '-m', 'torch.distributed.run',
diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py
index 354875de..befffbe9 100644
--- a/torchtitan/models/__init__.py
+++ b/torchtitan/models/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
+from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer, download_llama3_weights
 from torchtitan.models.opt import opt_configs, OPT, download_opt_weights, export_opt_weights
 
 models_config = {
@@ -26,7 +26,8 @@
 }
 
 model_name_to_weights_download_fns = {
-    "opt": download_opt_weights
+    "opt": download_opt_weights,
+    "llama3": download_llama3_weights
 }
 
 model_name_to_weights_export_fns = {
diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index 887a96cd..20cb947c 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -8,8 +8,9 @@
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
 from torchtitan.models.llama.model import ModelArgs, Transformer
+from torchtitan.models.llama.utils import download_llama3_weights
 
-__all__ = ["Transformer"]
+__all__ = ["Transformer", "download_llama3_weights"]
 
 llama2_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16),
@@ -30,6 +31,14 @@
 
 llama3_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
+    "1B": ModelArgs(
+        dim=2048,
+        n_layers=16,
+        n_heads=32,
+        n_kv_heads=8,
+        rope_theta=500000,
+        share_embeddings=True
+    ),
     "8B": ModelArgs(
         dim=4096,
         n_layers=32,
@@ -57,4 +66,4 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
-}
+}
\ No newline at end of file
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 2060519a..4845586e 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -35,6 +35,7 @@ class ModelArgs:
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
     norm_type: str = "rmsnorm"
+    share_embeddings: bool = False
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -241,7 +242,7 @@ def __init__(
         ffn_dim_multiplier: Optional[float],
     ):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
+        # hidden_dim = int(2 * hidden_dim / 3)
         # custom dim factor multiplier
         if ffn_dim_multiplier is not None:
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
@@ -377,7 +378,10 @@ def __init__(self, model_args: ModelArgs):
             model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
        )
 
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.output = None
+        if not self.model_args.share_embeddings:
+            self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+
         self.init_weights()
 
     def init_weights(self):
@@ -439,6 +443,8 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
+        if self.model_args.share_embeddings:
+            return torch.matmul(h, self.tok_embeddings.weight.t()).float()
         output = self.output(h).float() if self.output else h
         return output
 
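
For context, a minimal sketch (not part of the patch) of what the new share_embeddings path in forward computes: when the flag is set, no separate output projection is allocated and logits are produced by reusing the token-embedding matrix, which is how Llama 3.2 1B ties its input and output embeddings. The sizes below are toy values for illustration; the "1B" flavor above uses dim=2048.

import torch
import torch.nn as nn

vocab_size, dim = 1000, 64               # toy sizes; the "1B" flavor uses dim=2048
tok_embeddings = nn.Embedding(vocab_size, dim)

h = torch.randn(2, 16, dim)              # final hidden states: [batch, seq, dim]
logits = torch.matmul(h, tok_embeddings.weight.t()).float()   # tied output projection
assert logits.shape == (2, 16, vocab_size)
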
diff --git a/torchtitan/models/llama/utils.py b/torchtitan/models/llama/utils.py
new file mode 100644
index 00000000..504081d7
--- /dev/null
+++ b/torchtitan/models/llama/utils.py
@@ -0,0 +1,93 @@
+from transformers import AutoModelForCausalLM
+import torch
+from torchtitan.models.llama import Transformer
+from torchtitan.logging import logger
+
+
+# reverse_permute for sliced rotary
+def reverse_permute(w, n_heads, dim1, dim2):
+    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+
+# permute for sliced rotary
+def permute(w, n_heads, dim1, dim2):
+    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+
+def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
+    """
+    Get a mapping between state dict keys of different implementations.
+
+    Args:
+        num_layers (int): number of transformer layers (blocks).
+
+    Returns:
+        dict: mapping between local implementation state dict keys and hf implementation state dict keys
+    """
+    keys_mapping = {
+        'tok_embeddings.weight': 'model.embed_tokens.weight',
+        # per-layer weight mappings are added in the loop below
+        'norm.weight': 'model.norm.weight',
+        # "output.weight": 'lm_head.weight',
+    }
+    for layer in range(num_layers):
+        keys_mapping.update({
+            f'layers.{layer}.attention.wq.weight': f'model.layers.{layer}.self_attn.q_proj.weight',
+            f'layers.{layer}.attention.wk.weight': f'model.layers.{layer}.self_attn.k_proj.weight',
+            f'layers.{layer}.attention.wv.weight': f'model.layers.{layer}.self_attn.v_proj.weight',
+            f'layers.{layer}.attention.wo.weight': f'model.layers.{layer}.self_attn.o_proj.weight',
+            f'layers.{layer}.feed_forward.w1.weight': f'model.layers.{layer}.mlp.gate_proj.weight',
+            f'layers.{layer}.feed_forward.w3.weight': f'model.layers.{layer}.mlp.up_proj.weight',
+            f'layers.{layer}.feed_forward.w2.weight': f'model.layers.{layer}.mlp.down_proj.weight',
+            f'layers.{layer}.attention_norm.weight': f'model.layers.{layer}.input_layernorm.weight',
+            f'layers.{layer}.ffn_norm.weight': f'model.layers.{layer}.post_attention_layernorm.weight'
+        })
+
+    return keys_mapping
+
+
+def download_llama3_weights(model: Transformer, weights_path: str, source: str, token_embedding_size: int):
+    """
+    Load pretrained Llama 3 weights from `weights_path` into the titan `model`.
+
+    Args:
+        model (Transformer): the torchtitan model to load the weights into.
+        weights_path (str): Hugging Face model id or local path of the pretrained checkpoint.
+        source (str): source of the weights; only "huggingface" is supported.
+        token_embedding_size (int): target token embedding size (currently unused,
+            see the commented-out resize_token_embeddings call below).
+    """
+    if source == "huggingface":
+        hf_model = AutoModelForCausalLM.from_pretrained(weights_path)
+        # hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
+        keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers)
+        hf_state_dict = hf_model.state_dict()
+        corrected_state_dict = {}
+        for key, value in keys_mapping.items():
+            assert hf_state_dict[value].shape == model.state_dict()[key].shape
+            if "self_attn.q_proj.weight" in value:
+                corrected_state_dict[key] = reverse_permute(
+                    hf_state_dict[value], model.model_args.n_heads,
+                    model.model_args.dim, model.model_args.dim
+                )
+            elif "self_attn.k_proj.weight" in value:
+                kv_dim = model.model_args.dim // (model.model_args.n_heads // model.model_args.n_kv_heads)
+                corrected_state_dict[key] = reverse_permute(
+                    hf_state_dict[value], model.model_args.n_kv_heads,
+                    kv_dim, model.model_args.dim
+                )
+            else:
+                corrected_state_dict[key] = hf_state_dict[value]
+
+        with torch.device(model.freqs_cis.device):
+            corrected_state_dict["freqs_cis"] = model._precompute_freqs_cis()
+
+        model.load_state_dict(corrected_state_dict)
+        logger.info("Successfully loaded Llama 3 weights into the titan model.")
+
+        # from transformers import AutoTokenizer
+        # tok = AutoTokenizer.from_pretrained(weights_path)
+        # device = "cuda"
+        # hf_model.to(device)
+        # model.eval()
+        # text = "Hello world"
+        # data = tok(text, return_tensors="pt").to(device)
+        # hf_logits = hf_model(**data).logits
+        # logits = model(data.input_ids)
+        # print(torch.allclose(hf_logits, logits, atol=1e-4))
+    else:
+        raise NotImplementedError(f"Loading Llama 3 weights from source '{source}' is not supported.")
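
A quick, hedged sanity check (not included in the patch) for the layout conversion above: reverse_permute rearranges Hugging Face's interleaved rotary q/k projection weights into the sliced layout this repo expects, and permute should undo it exactly. The import path assumes the torchtitan.models.llama.utils module added in this diff; the sizes follow the 1B flavor's attention shapes.

import torch
from torchtitan.models.llama.utils import permute, reverse_permute

n_heads, n_kv_heads, dim = 32, 8, 2048        # attention sizes of the "1B" flavor
head_dim = dim // n_heads
kv_dim = n_kv_heads * head_dim

wq = torch.randn(dim, dim)                    # q_proj-shaped weight
wk = torch.randn(kv_dim, dim)                 # k_proj-shaped weight (GQA)

# round-trip: permute should restore the exact original tensors
assert torch.equal(permute(reverse_permute(wq, n_heads, dim, dim), n_heads, dim, dim), wq)
assert torch.equal(permute(reverse_permute(wk, n_kv_heads, kv_dim, dim), n_kv_heads, kv_dim, dim), wk)
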
diff --git a/torchtitan/tokenizers/tokenizer/custom.py b/torchtitan/tokenizers/tokenizer/custom.py
index 7f38deb8..4ce873f6 100644
--- a/torchtitan/tokenizers/tokenizer/custom.py
+++ b/torchtitan/tokenizers/tokenizer/custom.py
@@ -81,4 +81,6 @@ def n_words(self) -> int:
 
     @property
     def padded_n_words(self):
+        if self.n_words % self.pad_to_multiple_of == 0:
+            return self.n_words
         return self._n_words + self.pad_to_multiple_of - self._n_words % self.pad_to_multiple_of
\ No newline at end of file
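
An illustrative standalone sketch (not from the repo) of the padded-vocabulary arithmetic that the custom.py change fixes: the vocabulary is rounded up to the next multiple of pad_to_multiple_of, and with this fix it is left unchanged when it is already an exact multiple instead of being padded by a full extra block.

def padded_n_words(n_words: int, pad_to_multiple_of: int) -> int:
    # Mirrors the fixed property: round up to a multiple, no-op when already aligned.
    if n_words % pad_to_multiple_of == 0:
        return n_words
    return n_words + pad_to_multiple_of - n_words % pad_to_multiple_of

assert padded_n_words(128256, 128) == 128256   # already a multiple: unchanged (old code returned 128384)
assert padded_n_words(50000, 128) == 50048     # rounded up to the next multiple of 128
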
diff --git a/train_configs/llama3.2_1b.toml b/train_configs/llama3.2_1b.toml
new file mode 100644
index 00000000..da7ea32a
--- /dev/null
+++ b/train_configs/llama3.2_1b.toml
@@ -0,0 +1,68 @@
+# torchtitan Config.toml
+
+[job]
+dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
+description = "Llama 3.2 training"
+use_for_integration_test = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+enable_color_printing = true
+enable_aim = false
+save_aim_folder = "aim"
+
+[model]
+name = "llama3"
+flavor = "1B"
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
+tokenizer_path = "meta-llama/Llama-3.2-1B"
+
+[optimizer]
+name = "AdamW"
+lr = 1.0e-4
+
+[training]
+batch_size = 10
+gradient_accumulation_steps = 1
+seq_len = 2048
+warmup_steps = 500  # lr scheduler warm up, normally 20% of the train steps
+max_norm = 1.0  # grad norm clipping
+steps = 30
+data_parallel_degree = -1
+tensor_parallel_degree = 1
+compile = true
+# dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "chemlactica_train_mini"  # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
+dataset = "chemlactica_train"
+data_process_style = "chemlactica_style"
+
+[dataloader]
+num_workers = 4
+
+[experimental]
+pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false
+
+[checkpoint]
+enable_checkpoint = true
+load_folder = "meta-llama/Llama-3.2-1B"
+save_folder = "yerevann/Llama-3.2-1B"
+interval_type = "steps"
+interval = 1000
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "async_with_pinned_mem"  # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = 'none'  # ['none', 'selective', 'full']
+selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
+
+[float8]
+enable_float8_linear = false
diff --git a/train_configs/llama3.2_1b_conversion.toml b/train_configs/llama3.2_1b_conversion.toml
new file mode 100644
index 00000000..87cbedc5
--- /dev/null
+++ b/train_configs/llama3.2_1b_conversion.toml
@@ -0,0 +1,72 @@
+# torchtitan Config.toml
+
+[job]
+dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
+description = "Llama 3.2 training"
+use_for_integration_test = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+enable_color_printing = true
+enable_aim = false
+save_aim_folder = "aim"
+
+[model]
+name = "llama3"
+flavor = "1B"
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
+# test tokenizer.model, for debug purpose only
+# tokenizer_path = "./test/assets/test_tiktoken.model"
+tokenizer_path = "meta-llama/Llama-3.2-1B"
+
+[optimizer]
+name = "AdamW"
+lr = 1.0e-4
+
+[training]
+batch_size = 1
+gradient_accumulation_steps = 3
+seq_len = 2048
+warmup_steps = 500  # lr scheduler warm up, normally 20% of the train steps
+max_norm = 1.0  # grad norm clipping
+steps = 10
+data_parallel_degree = -1
+tensor_parallel_degree = 1
+compile = true
+# dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "chemlactica_train_mini"  # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
+dataset = "chemlactica_train"
+data_process_style = "chemlactica_style"
+
+[experimental]
+pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false
+
+[checkpoint]
+enable_checkpoint = true
+load_folder = "meta-llama/Llama-3.2-1B"
+save_folder = "meta-llama/Llama-3.2-1B"
+interval_type = "steps"
+interval = 1000
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "async_with_pinned_mem"  # ["disabled", "async", "async_with_pinned_mem"]
+
+[model_download_export]
+to_titan = true
+weights_source = "huggingface"
+# to_hf = true
+
+[activation_checkpoint]
+mode = 'none'  # ['none', 'selective', 'full']
+selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
+
+[float8]
+enable_float8_linear = false
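
Finally, a hypothetical sketch of how a run with the new llama3.2_1b.toml config might be submitted, mirroring the CommandFunction pattern in submitit_train.py above. The executor parameters, the --nproc_per_node value, and the --job.config_file flag are assumptions for illustration, not taken from this diff.

import submitit

executor = submitit.AutoExecutor(folder="slurm_logs")             # hypothetical log folder
executor.update_parameters(nodes=1, gpus_per_node=8, timeout_min=24 * 60)

function = submitit.helpers.CommandFunction([
    'python3', '-m', 'torch.distributed.run', '--nproc_per_node', '8',
    'train.py', '--job.config_file', './train_configs/llama3.2_1b.toml',
])
job = executor.submit(function)
print(job.job_id)
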