Merge pull request #31 from YerevaNN/model_loading
Model loading
MenuaB authored Oct 19, 2024
2 parents 4458b78 + 7b97d4d commit 8e38015
Showing 26 changed files with 2,509,133 additions and 185 deletions.
4 changes: 2 additions & 2 deletions submitit_train.py
@@ -6,13 +6,13 @@
 
 if __name__ == "__main__":
     executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
-    n_gpus = 4
+    n_gpus = 8
     executor.update_parameters(
         name="titan", timeout_min=3 * 24 * 60,
         gpus_per_node=n_gpus,
         nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
         slurm_additional_parameters={
-            "partition": "a100"
+            "partition": "h100"
         }
     )
49 changes: 49 additions & 0 deletions submitit_train_hparam_tuning.py
@@ -0,0 +1,49 @@
+import submitit
+import datetime
+import yaml
+import os
+
+
+if __name__ == "__main__":
+    executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
+    n_gpus = 8
+    executor.update_parameters(
+        name="titan", timeout_min=3 * 60,
+        gpus_per_node=n_gpus,
+        nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
+        slurm_additional_parameters={
+            "partition": "h100"
+        }
+    )
+
+    hparams = {
+        # "optimizer.lr": ["1.2e-3", "9e-4", "6e-4", "3e-4"],
+        # "optimizer.lr": ["8e-4", "6e-4", "4e-4", "2e-4"],
+        # "optimizer.lr": ["2.5e-4"],
+        # "optimizer.lr": ["1e-4", "8e-5", "6e-5", "4e-5", "2e-5"],
+    }
+
+    jobs = []
+    with executor.batch():
+        for _ in range(1):
+            for hparam_name, value in hparams.items():
+                for v in value:
+                    # train_config = './train_configs/chemlactica_125m.toml'
+                    # train_config = './train_configs/chemlactica_1.3b.toml'
+                    train_config = './train_configs/llama3.2_1b.toml'
+                    # train_config = './train_configs/debug_model.toml'
+                    function = submitit.helpers.CommandFunction([
+                        'python3', '-m', 'torch.distributed.run',
+                        '--nproc_per_node', f'{n_gpus}',
+                        '--rdzv_backend', 'c10d',
+                        '--rdzv_endpoint', 'localhost:0',
+                        '--local-ranks-filter', '0',
+                        '--role', 'rank', '--tee', '3',
+                        'train.py',
+                        '--job.config_file', train_config,
+                        f'--{hparam_name}', v
+                    ])
+                    print(' '.join(function.command))
+                    # subprocess.run(function.command)
+                    job = executor.submit(function)
+                    jobs.append(job)
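Note: the hparams dict is committed empty, so the loops above submit nothing until a grid is filled in. A minimal usage sketch (the learning-rate grid is taken from the commented-out lines above, not a new assumption):

    hparams = {
        "optimizer.lr": ["8e-4", "6e-4", "4e-4", "2e-4"],
    }
    # With this grid the script builds and submits four commands of the form
    #   python3 -m torch.distributed.run --nproc_per_node 8 ... train.py \
    #       --job.config_file ./train_configs/llama3.2_1b.toml --optimizer.lr 8e-4
    # executor.batch() collects every executor.submit() inside the with-block
    # and submits them together as one Slurm job array.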
6 changes: 5 additions & 1 deletion torchtitan/checkpoint.py
@@ -236,7 +236,11 @@ def __init__(
         for idx, lr_scheduler in enumerate(lr_schedulers):
             self.states[f"lr_scheduler_{idx}"] = lr_scheduler
 
-        self.save_folder = os.path.join(job_config.job.dump_folder, os.path.join(ckpt_config.save_folder, experiment_hash))
+
+        if job_config.model_download_export.to_hf or job_config.model_download_export.to_titan:
+            self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
+        else:
+            self.save_folder = os.path.join(job_config.job.dump_folder, os.path.join(ckpt_config.save_folder, experiment_hash))
         self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
         self.interval_type = (
             IntervalType.SECONDS
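Note: os.path.join is variadic, so the nested call in the else branch is equivalent to a flat three-argument join. A small sketch of the two resulting layouts (folder names are hypothetical):

    import os

    dump_folder, save_folder, experiment_hash = "outputs", "checkpoints", "a1b2c3"  # hypothetical values
    # export runs (to_hf / to_titan): flat folder, no experiment hash
    print(os.path.join(dump_folder, save_folder))                                  # outputs/checkpoints
    # training runs: scoped under the Aim experiment hash; the nested join
    # equals the flat three-argument join
    print(os.path.join(dump_folder, os.path.join(save_folder, experiment_hash)))   # outputs/checkpoints/a1b2c3
    print(os.path.join(dump_folder, save_folder, experiment_hash))                 # outputs/checkpoints/a1b2c3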
8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -423,6 +423,14 @@ def __init__(self):
                 When enable_checkpoint is set to true, checkpoints will be loaded from {--job.dump_folder}/{--checkpoint.load_folder}.
             """,
         )
+        self.parser.add_argument(
+            "--checkpoint.load_at_step",
+            type=int,
+            default=0,
+            help="""
+                The step from which to load the checkpoint.
+            """,
+        )
         self.parser.add_argument(
             "--checkpoint.save_folder",
             type=str,
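Note: a sketch of how a load_at_step flag is typically consumed on the loading side. The step-<n> folder naming mirrors torchtitan's checkpoint layout, and treating the 0 default as "latest" is an assumption here, not something this diff shows:

    import os

    def resolve_checkpoint_path(load_folder: str, load_at_step: int) -> str:
        # Hypothetical consumer of --checkpoint.load_at_step (not part of this PR).
        if load_at_step == 0:  # assumed: default falls back to the newest step-<n> folder
            steps = [int(d.split("-")[1]) for d in os.listdir(load_folder) if d.startswith("step-")]
            return os.path.join(load_folder, f"step-{max(steps)}")
        return os.path.join(load_folder, f"step-{load_at_step}")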
16 changes: 11 additions & 5 deletions torchtitan/datasets/hf_datasets.py
@@ -125,10 +125,10 @@ def __init__(
         dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))
         ds = load_dataset("text", data_files=dataset_files, split="train", streaming="valid" not in dataset_name)
 
-        try:
-            data_processing_fn = _supported_data_processing_styles[data_processing_style]
-        except KeyError as e:
-            raise ValueError(f"Unsupported data processing style: {data_processing_style}")
+        # try:
+        data_processing_fn = _supported_data_processing_styles[data_processing_style]
+        # except KeyError as e:
+        #     raise ValueError(f"Unsupported data processing style: {data_processing_style}")
 
         # TODO: support shuffling and checkpointing
         self.dataset_name = dataset_name
@@ -151,6 +151,9 @@ def __init__(
         # debugging dataloader yielding
         self.special_mode = str(special_mode)
 
+        # number of samples to log
+        self.number_of_samples_to_log = 5
+
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
 
@@ -162,7 +165,10 @@ def __iter__(self):
                 continue
 
             for sample_json in self._get_data_iter():
-                sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type)
+                sample_text = self.data_processing_fn(sample_json["text"], self.rng, self.representation_type)
+                if self.number_of_samples_to_log > 0:
+                    logger.info(f"Sample: {sample_text}")
+                    self.number_of_samples_to_log -= 1
                 sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                 self._all_tokens.extend(sample_tokens)
                 self._sample_idx += 1
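Note: the encoded samples above feed a fixed-length packing buffer (max_buffer_token_len = 1 + self.seq_len in the unchanged lines). A minimal sketch of that packing, reconstructed from the visible buffer logic rather than copied from this diff:

    import torch

    def pack(all_tokens: list, seq_len: int):
        # slice 1 + seq_len tokens at a time and shift by one,
        # so position t predicts token t + 1
        max_buffer_token_len = 1 + seq_len
        while len(all_tokens) >= max_buffer_token_len:
            x = torch.LongTensor(all_tokens[:max_buffer_token_len])
            all_tokens = all_tokens[max_buffer_token_len:]
            yield x[:-1], x[1:]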
15 changes: 9 additions & 6 deletions torchtitan/metrics.py
@@ -103,6 +103,9 @@ def __init__(self, hash, experiment_name, log_dir, save_aim_folder, enable_aim):
                 self.writer = AimLogger(save_aim_folder, experiment=experiment_name)
             else:
                 self.writer = AimLogger(save_aim_folder)
+            self.experiment_hash = self.writer.experiment.hash
+        else:
+            self.experiment_hash = "default"
 
     def log(self, metrics: Dict[str, Any], step: int):
         if self.writer is not None:
@@ -116,11 +119,6 @@ def log_hparams(self, config):
         if self.writer is not None:
             self.writer.experiment['hparams'] = config
 
-    @property
-    def experiment_hash(self):
-        if self.writer is None:
-            return "default"
-        return self.writer._run.hash
 
 def build_metric_logger(
     job_config: JobConfig, parallel_dims: ParallelDims
@@ -144,5 +142,10 @@ def build_metric_logger(
         f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}"
     )
     enable_aim = torch.distributed.get_rank() == 0
-    return MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim)
+    metric_logger = MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim)
+
+    experiment_hash_list = [metric_logger.experiment_hash]
+    # broadcast aim experiment hash to all ranks
+    torch.distributed.broadcast_object_list(experiment_hash_list, src=0)
+    metric_logger.experiment_hash = experiment_hash_list[0]
+    return metric_logger
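Note: only rank 0 constructs an Aim run (enable_aim is gated on get_rank() == 0), so every other rank starts with the "default" placeholder and receives the real hash via the broadcast. torch.distributed.broadcast_object_list pickles the objects on src and overwrites the list in place on every other rank; a minimal sketch of the pattern (assumes an initialized process group):

    import torch.distributed as dist

    def share_hash(local_hash: str) -> str:
        # rank 0 contributes its value; other ranks' entries are overwritten in place
        obj = [local_hash]
        dist.broadcast_object_list(obj, src=0)
        return obj[0]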
5 changes: 3 additions & 2 deletions torchtitan/models/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer, download_llama3_weights
+from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer, download_llama3_weights, export_llama3_weights
 from torchtitan.models.opt import opt_configs, OPT, download_opt_weights, export_opt_weights
 
 models_config = {
@@ -31,5 +31,6 @@
 }
 
 model_name_to_weights_export_fns = {
-    "opt": export_opt_weights
+    "opt": export_opt_weights,
+    "llama3": export_llama3_weights
 }
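Note: this dict is a name-keyed registry, so callers dispatch on the model name instead of importing a model-specific converter. A sketch of the assumed call site (the dispatch itself is not shown in this diff; the argument list mirrors export_llama3_weights further down):

    export_fn = model_name_to_weights_export_fns["llama3"]
    export_fn(model, save_dir, tokenizer, token_embedding_size)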
4 changes: 2 additions & 2 deletions torchtitan/models/llama/__init__.py
@@ -8,9 +8,9 @@
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
 from torchtitan.models.llama.model import ModelArgs, Transformer
-from torchtitan.models.llama.utils import download_llama3_weights
+from torchtitan.models.llama.utils import download_llama3_weights, export_llama3_weights
 
-__all__ = ["Transformer", download_llama3_weights]
+__all__ = ["Transformer", "download_llama3_weights", "export_llama3_weights"]
 
 llama2_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16),
94 changes: 77 additions & 17 deletions torchtitan/models/llama/utils.py
@@ -1,7 +1,9 @@
 
 from transformers import AutoModelForCausalLM
 import torch
 from torchtitan.models.llama import Transformer
 from torchtitan.logging import logger
+import os
 
 
 # reverse_permute for sliced rotary
@@ -14,7 +16,7 @@ def permute(w, n_heads, dim1, dim2):
     return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 
 
-def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
+def get_hf_llama3_state_dict_keys_mapping(num_layers: int, include_lm_head: bool = False):
     """
     Get a mapping between state dict keys of different implementations.
@@ -29,8 +31,9 @@ def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
         'tok_embeddings.weight': 'model.embed_tokens.weight',
         # add layer weight mappings here
         'norm.weight': 'model.norm.weight',
-        # "output.weight": 'lm_head.weight',
     }
+    if include_lm_head:
+        keys_mapping['output.weight'] = 'lm_head.weight'
     for layer in range(num_layers):
         keys_mapping.update({
             f'layers.{layer}.attention.wq.weight': f'model.layers.{layer}.self_attn.q_proj.weight',
@@ -47,14 +50,32 @@ def get_hf_llama3_state_dict_keys_mapping(num_layers: int):
     return keys_mapping
 
 
-def download_llama3_weights(model: Transformer, weights_path: str, source: str, token_embedding_size: int):
+def verify_logits_matching(
+    model: Transformer,
+    hf_model,
+    tokenizer,
+    atol: float,
+    prompts=["Hello world", "The capital of France is "]
+):
+    device = "cuda"
+    hf_model.to(device)
+    model.eval()
+    for prompt in prompts:
+        data = tokenizer(prompt, return_tensors="pt").to(device)
+        hf_logits = hf_model(**data).logits
+        logits = model(data.input_ids)
+        assert torch.allclose(hf_logits, logits, atol=atol)
+
+
+def download_llama3_weights(model: Transformer, weights_path: str, tokenizer, source: str, token_embedding_size: int):
     """
     write docs
     """
     if source == "huggingface":
         hf_model = AutoModelForCausalLM.from_pretrained(weights_path)
-        # hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
-        keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers)
+        hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
+        include_lm_head = not model.model_args.share_embeddings
+        keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers, include_lm_head)
         hf_state_dict = hf_model.state_dict()
         corrected_state_dict = {}
         for key, value in keys_mapping.items():
@@ -77,17 +98,56 @@ def download_llama3_weights(model: Transformer, weights_path: str, source: str,
         corrected_state_dict["freqs_cis"] = model._precompute_freqs_cis()
 
         model.load_state_dict(corrected_state_dict)
-        logger.info("Successfully loaded Llama 3 model to the titan model.")
 
-        # from transformers import AutoTokenizer
-        # tok = AutoTokenizer.from_pretrained(weights_path)
-        # device = "cuda"
-        # hf_model.to(device)
-        # model.eval()
-        # text = "Hello world"
-        # data = tok(text, return_tensors="pt").to(device)
-        # hf_logits = hf_model(**data).logits
-        # logits = model(data.input_ids)
-        # print(torch.allclose(hf_logits, logits, atol=1e-4))
+        verify_logits_matching(model=model, hf_model=hf_model, tokenizer=tokenizer, atol=1e-1)
+        logger.info("Successfully loaded Llama 3 model to titan model.")
     else:
         raise NotImplementedError
+
+
+def map_n_layers_to_model_name(n_layers):
+    return {
+        16: "meta-llama/Llama-3.2-1B",
+    }[n_layers]
+
+
+def export_llama3_weights(model: Transformer, save_dir, tokenizer, token_embedding_size: int):
+    """
+    write docs
+    """
+    weights_path = map_n_layers_to_model_name(model.n_layers)
+    hf_model = AutoModelForCausalLM.from_pretrained(weights_path)
+    hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
+    include_lm_head = not model.model_args.share_embeddings
+    keys_mapping = get_hf_llama3_state_dict_keys_mapping(model.n_layers, include_lm_head)
+    state_dict = model.state_dict()
+    corrected_state_dict = {}
+    for key, value in keys_mapping.items():
+        assert hf_model.state_dict()[value].shape == state_dict[key].shape
+        if "self_attn.q_proj.weight" in value:
+            corrected_state_dict[value] = permute(
+                state_dict[key], model.model_args.n_heads,
+                model.model_args.dim, model.model_args.dim
+            )
+        elif "self_attn.k_proj.weight" in value:
+            kv_dim = model.model_args.dim // (model.model_args.n_heads // model.model_args.n_kv_heads)
+            corrected_state_dict[value] = permute(
+                state_dict[key], model.model_args.n_kv_heads,
+                kv_dim, model.model_args.dim
+            )
+        else:
+            corrected_state_dict[value] = state_dict[key]
+
+    if model.model_args.share_embeddings:
+        assert hf_model.state_dict()["lm_head.weight"].shape == state_dict["tok_embeddings.weight"].shape
+        corrected_state_dict["lm_head.weight"] = state_dict["tok_embeddings.weight"]
+
+    hf_model.load_state_dict(corrected_state_dict)
+    verify_logits_matching(
+        model=model,
+        hf_model=hf_model,
+        tokenizer=tokenizer,
+        atol=1e-2,
+        prompts=["", "[QED]", "[SAFE]"]
+    )
+    hf_model.save_pretrained(save_dir)
+    logger.info(f"Successfully exported Llama 3 model to huggingface model at {save_dir}.")
8 changes: 7 additions & 1 deletion torchtitan/models/opt/model.py
@@ -35,6 +35,7 @@ class ModelArgs:
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
     norm_type: str = "layernorm_bias"
+    share_embeddings: bool = False
 
 
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -306,7 +307,10 @@ def __init__(self, model_args: ModelArgs):
         self.norm = build_norm(
             model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
         )
-        self.output = lambda x: F.linear(x, self.tok_embeddings.weight)
+        # self.output = lambda x: F.linear(x, self.tok_embeddings.weight)
+        self.output = None
+        if not self.model_args.share_embeddings:
+            self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
         self.init_weights()
 
@@ -361,6 +365,8 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h)
 
         h = self.norm(h) if self.norm else h
+        if self.model_args.share_embeddings:
+            return torch.matmul(h, self.tok_embeddings.weight.t()).float()
         output = self.output(h).float() if self.output else h
         return output
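Note: with share_embeddings the output head reuses the token-embedding matrix, so no separate lm_head parameters exist; torch.matmul(h, W.t()) computes the same thing as the commented-out F.linear(x, W). A tiny sketch of that equivalence (shapes are hypothetical):

    import torch
    import torch.nn.functional as F

    h = torch.randn(4, 16, 256)   # (batch, seq, dim)
    W = torch.randn(1000, 256)    # tok_embeddings.weight: (vocab_size, dim)
    assert torch.allclose(torch.matmul(h, W.t()), F.linear(h, W), atol=1e-5)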