Merge pull request #17 from YerevaNN/validation
Validation loop
philippguevorguian authored Oct 2, 2024
2 parents be97fc3 + f9b1b33 commit 225c78c
Showing 14 changed files with 641 additions and 119 deletions.
6 changes: 3 additions & 3 deletions run_llama_train.sh
@@ -1,4 +1,4 @@
#!/usr/bin/bash
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

@@ -11,9 +11,9 @@ set -ex
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
NGPU=${NGPU:-"2"}
LOG_RANK=0,1
LOG_RANK=${LOG_RANK:-0,1}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
MAX_RESTARTS=5
MAX_RESTARTS=0

overrides=""
if [ $# -ne 0 ]; then
4 changes: 2 additions & 2 deletions submitit_train.py
@@ -6,13 +6,13 @@

if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
n_gpus = 8
n_gpus = 4
executor.update_parameters(
name="titan", timeout_min=3 * 24 * 60,
gpus_per_node=n_gpus,
nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
slurm_additional_parameters={
"partition": "h100"
"partition": "a100"
}
)

300 changes: 300 additions & 0 deletions test/assets/chemlactica_valid_mini/file0.jsonl

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions torchtitan/aim.py
@@ -17,9 +17,9 @@ def __init__(
capture_terminal_logs: Optional[bool] = True,
run_name: Optional[str] = None,
run_hash: Optional[str] = None,
train_metric_prefix: Optional[str] = 'train_',
val_metric_prefix: Optional[str] = 'val_',
test_metric_prefix: Optional[str] = 'test_',
train_metric_prefix: Optional[str] = 'train/',
val_metric_prefix: Optional[str] = 'val/',
test_metric_prefix: Optional[str] = 'test/',
):
super().__init__()

@@ -73,14 +73,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
name = k
context = {}
if self._train_metric_prefix and name.startswith(self._train_metric_prefix):
name = name[len(self._train_metric_prefix) :]
context['subset'] = 'train'
name = name[len(self._train_metric_prefix) :]
elif self._test_metric_prefix and name.startswith(self._test_metric_prefix):
name = name[len(self._test_metric_prefix) :]
context['subset'] = 'test'
name = name[len(self._test_metric_prefix) :]
elif self._val_metric_prefix and name.startswith(self._val_metric_prefix):
name = name[len(self._val_metric_prefix) :]
context['subset'] = 'val'
name = name[len(self._val_metric_prefix) :]
self.experiment.track(v, name=name, step=step, context=context)

def finalize(self) -> None:
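
The prefix change above (from 'train_'/'val_'/'test_' to 'train/'/'val/'/'test/') only affects how metric names are split into an Aim metric name plus a 'subset' context. A standalone Python sketch of that routing, using illustrative metric names that are not part of this PR:

# Mirrors the prefix-stripping logic in log_metrics above, outside the class.
prefixes = {"train/": "train", "val/": "val", "test/": "test"}
metrics = {"train/loss": 2.31, "val/loss": 2.47, "lr": 3e-4}

for name, value in metrics.items():
    context = {}
    for prefix, subset in prefixes.items():
        if name.startswith(prefix):
            context["subset"] = subset
            name = name[len(prefix):]
            break
    print(name, value, context)
# loss 2.31 {'subset': 'train'}
# loss 2.47 {'subset': 'val'}
# lr 0.0003 {}
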
28 changes: 23 additions & 5 deletions torchtitan/config_manager.py
@@ -129,6 +129,7 @@ def __init__(self):
)
self.parser.add_argument(
"--metrics.enable_aim",
default=False,
action="store_true",
help="Whether to log metrics to aim",
)
@@ -205,14 +206,9 @@
)
self.parser.add_argument(
"--training.data_processing_style",
choices=["chemlactica_style"],
default="chemlactica_style",
help="""
Specifies the method for processing data prior to tokenization.""",
)
self.parser.add_argument(
"--training.batch_size", type=int, default=8, help="Batch size"
)
self.parser.add_argument(
"--training.gradient_accumulation_steps",
type=int,
@@ -390,6 +386,28 @@ def __init__(self):
help="Python garbage control scheduling interval, in steps",
)

# validation configs
self.parser.add_argument(
"--validation.batch_size", type=int, default=None
)
self.parser.add_argument(
"--validation.dataset", type=str, help="Dataset to use", default=None
)
self.parser.add_argument(
"--validation.dataset_path",
type=str,
help="""
Path to the dataset for validation in the file system. If provided, data will be
loaded from this path instead of downloaded.""",
default=None,
)
self.parser.add_argument(
"--validation.valid_freq", type=int, default=1024, help="How often to evaluate the model and log metrics to aim."
)
self.parser.add_argument(
"--validation.enable_valid", type=bool, default=False, help="Whether to do validation."
)

# checkpointing configs
self.parser.add_argument(
"--checkpoint.enable_checkpoint",
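
For reference, a standalone sketch of how the new --validation.* flags parse, using plain argparse to mirror the definitions above (flag values are illustrative, not from this PR). One behavior worth knowing: type=bool applies Python's bool() to the raw string, so any non-empty value, including "False", parses as True.

# Plain-argparse mirror of the validation flags added above (illustrative only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation.batch_size", type=int, default=None)
parser.add_argument("--validation.dataset", type=str, default=None)
parser.add_argument("--validation.dataset_path", type=str, default=None)
parser.add_argument("--validation.valid_freq", type=int, default=1024)
parser.add_argument("--validation.enable_valid", type=bool, default=False)

args = parser.parse_args([
    "--validation.dataset", "chemlactica_valid_mini",
    "--validation.valid_freq", "2048",
    "--validation.enable_valid", "True",
])
print(vars(args))
# bool("False") is also True, so "--validation.enable_valid False" would still
# enable validation; only an empty string maps to False with type=bool.
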
42 changes: 11 additions & 31 deletions torchtitan/datasets/hf_datasets.py
@@ -6,7 +6,6 @@

import pickle
from typing import Any, Dict, List, Optional
from pathlib import Path
import glob
import os

@@ -26,18 +25,23 @@

from torchtitan.tokenizers.tokenizer import Tokenizer
from torchtitan.logging import logger
from torchtitan.utils.dataset_utils import chemlactica_style_data_processing,create_fresh_file_store
from torchtitan.utils.dataset_utils import chemlactica_style_data_processing

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
# a dataset repository on the HF hub
_supported_datasets = {
# train
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",
"chemlactica_train_mini": "test/assets/chemlactica_train_mini",
"chemlactica_train": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form"
"chemlactica_train": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form",

# valid
"chemlactica_valid": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form",
"chemlactica_valid_mini": "test/assets/chemlactica_valid_mini"
}

_supported_data_processing_styles = {
@@ -93,7 +97,6 @@ def __init__(
rank: int = 0,
infinite: bool = False,
special_mode = None,
store = None,
) -> None:
# allow user to pass in a (local or HF hub) path to use unsupported datasets
if dataset_name not in _supported_datasets:
@@ -120,13 +123,13 @@
ds = load_dataset(dataset_path, split="train")
else:
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))
ds = load_dataset("text", data_files=dataset_files, split="train", streaming=True)
ds = load_dataset("text", data_files=dataset_files, split="train", streaming="valid" not in dataset_name)

try:
data_processing_fn = _supported_data_processing_styles[data_processing_style]
except KeyError as e:
raise ValueError(f"Unsupported data processing style: {data_processing_style}")
# data_processing_fn = lambda x, e: str(x)


# TODO: support shuffling and checkpointing
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
@@ -138,12 +141,6 @@
self.world_size = world_size
self.representation_type = representation_type

# for non sync communication between ranks
if not self.infinite and store:
self.store = store
else:
self.store = None

# variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []
@@ -154,12 +151,6 @@
# debugging dataloader yielding
self.special_mode = str(special_mode)

def _some_rank_finished(self) -> bool:
if not self.infinite and self.store.num_keys() > 1: # one key used for coordination, more than one means one of the ranks exhausted data
return True
else:
return False

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

@@ -171,8 +162,6 @@
continue

for sample_json in self._get_data_iter():
if self._some_rank_finished():
break
sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
@@ -187,9 +176,7 @@
yield input, label

if not self.infinite:
self.store.set(str(self.rank),"Done")
logger.warning(f"Dataset {self.dataset_name} has run out of data")
self.store.wait([str(k) for k in range(self.world_size)]) # making sure all ranks get to this point
break
else:
# Reset offset for the next iteration
@@ -261,16 +248,9 @@ def build_hf_data_loader(
pin_memory: bool = False,
num_workers: int = 2,
special_mode = None,
context = "train",
):
if not infinite:
store_identifier = f"rankstore_{context}_{dataset_name}"
data_completion_store = create_fresh_file_store(store_identifier,world_size)
else:
data_completion_store = None

hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store
dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode
)

return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers)
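
One detail from the hunk above: the streaming flag is now derived from the dataset name, so validation sets (any name containing "valid") are loaded eagerly while training sets keep streaming. A small sketch of that selection, using the names and paths registered in _supported_datasets:

# Reproduces only the streaming decision made in the load_dataset call above.
_local_datasets = {
    "chemlactica_train_mini": "test/assets/chemlactica_train_mini",
    "chemlactica_valid_mini": "test/assets/chemlactica_valid_mini",
}

for dataset_name, dataset_path in _local_datasets.items():
    streaming = "valid" not in dataset_name   # eager load for validation splits
    print(dataset_name, "->", dataset_path, "streaming:", streaming)
# chemlactica_train_mini -> test/assets/chemlactica_train_mini streaming: True
# chemlactica_valid_mini -> test/assets/chemlactica_valid_mini streaming: False
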
27 changes: 27 additions & 0 deletions torchtitan/utils/common_utils.py
@@ -9,6 +9,7 @@
from dataclasses import dataclass
from datetime import timedelta
from typing import Union
import contextlib

import torch
import torch.distributed._functional_collectives as funcol
@@ -17,6 +18,21 @@
from torchtitan.logging import logger


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)
yield

return context


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
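
A quick usage sketch for the get_train_context helper added above. The tiny model and data are placeholders, not from this PR; both features are disabled so the context enters nothing and the snippet runs on a plain PyTorch install.

import torch

train_context = get_train_context(
    enable_loss_parallel=False, enable_compiled_autograd=False
)

model = torch.nn.Linear(8, 4)
inputs, labels = torch.randn(2, 8), torch.randint(0, 4, (2,))

with train_context():  # wraps the forward and backward of one step
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
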
@@ -133,6 +149,17 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
return flop_per_token


def get_num_flop_per_token_forward(num_params: int, model_config, seq_len) -> int:
l, h, q, t = (
model_config.n_layers,
model_config.n_heads,
model_config.dim // model_config.n_heads,
seq_len,
)
flop_per_token = 2 * num_params + 4 * l * h * q * t
return flop_per_token


# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
def get_peak_flops(device_name: str) -> int:
if "A100" in device_name:
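
The forward-only estimate above complements the existing per-token FLOP count. Assuming get_num_flop_per_token uses the standard 6 * num_params + 12 * l * h * q * t forward+backward formula (not shown in this hunk), the forward-only value is exactly one third of it, since backward costs roughly twice the forward pass. A quick numeric check with hypothetical Llama-like sizes:

num_params = 7_000_000_000               # hypothetical ~7B model
l, h, q, t = 32, 32, 128, 2048           # layers, heads, head dim, seq len

flops_fwd = 2 * num_params + 4 * l * h * q * t        # forward-only (above)
flops_fwd_bwd = 6 * num_params + 12 * l * h * q * t   # assumed fwd+bwd formula

assert 3 * flops_fwd == flops_fwd_bwd
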
10 changes: 0 additions & 10 deletions torchtitan/utils/dataset_utils.py
@@ -8,16 +8,6 @@
import os
from pathlib import Path

TEMPORARY_FILES_PATH = Path('/tmp')

def create_fresh_file_store(store_identifier: str, world_size: int):
store_file = TEMPORARY_FILES_PATH.joinpath(store_identifier)
if store_file.exists():
store_file.unlink() # we want to always remove prior files since they don't correspond

stop_ranks_store = torch.distributed.FileStore(str(store_file),world_size)
return stop_ranks_store


def load_jsonl_line(jsonl_line):
try: