Merge pull request #60 from AGI-Collective/dataloader
Added dataloader checkpointing
kshitijkg authored Oct 24, 2023
2 parents 8e52a4f + 0330e50 commit 8a8e91d
Showing 2 changed files with 54 additions and 27 deletions.
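In short, the change makes the streaming dataloaders resumable: whenever a model checkpoint is saved, each loader's state_dict() is pickled to the path in that split's streaming data config, and on startup the file, if present, is loaded back via load_state_dict(). A minimal sketch of the round trip; the helper names are illustrative, not part of the commit:

import os
import pickle as pkl

def save_loader_state(dataloader, path):
    # Persist the loader's resumption state (the loader must expose
    # state_dict(), as the streaming dataloader in this diff does).
    with open(path, 'wb') as f:
        pkl.dump(dataloader.state_dict(), f)

def maybe_restore_loader_state(dataloader, path):
    # Fast-forward the loader if a saved state exists; otherwise it
    # starts from iteration 0, mirroring the diff's else branch.
    if os.path.exists(path):
        with open(path, 'rb') as f:
            dataloader.load_state_dict(pkl.load(f))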
50 changes: 34 additions & 16 deletions megatron/data/data_utils.py
@@ -32,6 +32,10 @@

from streaming import Stream, StreamingDataset
from omegaconf import OmegaConf as om
import os
import pickle as pkl

def make_data_loader(dataset, neox_args):
"""Build dataloader given an input dataset."""
if dataset is None:
@@ -367,25 +371,39 @@ def prepare_config(dataset_config):
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    # Shift the start iterations.
    if train_dataloader is not None:
        train_state_dict_path = neox_args.train_streaming_data_config['state_dict_path']
        if os.path.exists(train_state_dict_path):
            with open(train_state_dict_path, 'rb') as f:
                train_state_dict = pkl.load(f)
            train_dataloader.load_state_dict(train_state_dict)
            # Print all the key/value pairs of the restored state dict
            for k, v in train_state_dict.items():
                print_rank_0(f"Loaded {k} with value {v}")
        else:
            print_rank_0("setting training data start iteration to 0")

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None
    valid_state_dict_path = neox_args.valid_streaming_data_config['state_dict_path']
    if os.path.exists(valid_state_dict_path):
        with open(valid_state_dict_path, 'rb') as f:
            valid_state_dict = pkl.load(f)
        valid_dataloader.load_state_dict(valid_state_dict)
        for k, v in valid_state_dict.items():
            print_rank_0(f"Loaded {k} with value {v}")
    else:
        print_rank_0("setting validation data start iteration to 0")

    return train_data_iterator, valid_data_iterator, test_data_iterator
    return train_dataloader, valid_dataloader, test_dataloader

def build_train_valid_test_data_iterators(neox_args):
"""XXX"""
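The restore logic above is duplicated for the train and validation splits. A sketch of the same behavior factored into one helper; restore_split and the split_name argument are illustrative, not part of the commit:

import os
import pickle as pkl

from megatron.utils import print_rank_0  # assumed location, as in upstream gpt-neox

def restore_split(dataloader, streaming_data_config, split_name):
    # Reload a pickled loader state for one split, if one was saved.
    state_dict_path = streaming_data_config['state_dict_path']
    if os.path.exists(state_dict_path):
        with open(state_dict_path, 'rb') as f:
            state_dict = pkl.load(f)
        dataloader.load_state_dict(state_dict)
        for k, v in state_dict.items():
            print_rank_0(f"Loaded {split_name} {k} with value {v}")
    else:
        print_rank_0(f"setting {split_name} data start iteration to 0")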
31 changes: 20 additions & 11 deletions megatron/training.py
@@ -57,6 +57,7 @@
)
from megatron.model.gpt2_model import cross_entropy
from eval_tasks import run_eval_harness
import pickle as pkl


def mup_weights_reinit(neox_args, model):
@@ -198,14 +199,14 @@ def pretrain(neox_args):
    # Data stuff.
    timers("train/valid/test data iterators").start()
    (
        train_data_iterator,
        valid_data_iterator,
        test_data_iterator,
        train_dataloader,
        valid_dataloader,
        test_dataloader,
    ) = build_streaming_train_valid_test_data_iterators(neox_args=neox_args)
    timers("train/valid/test data iterators").stop()

    if neox_args.use_mup and neox_args.coord_check:
        mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)
        mup_coord_check(neox_args, timers, lr_scheduler, iter(train_dataloader) if train_dataloader is not None else None)

    # Print setup timing.
    print_rank_0("done with setups ...")
@@ -223,15 +224,18 @@ def pretrain(neox_args):
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    with open(neox_args.train_streaming_data_config['state_dict_path'], 'wb') as f:
        pkl.dump(train_dataloader.state_dict(), f)
    with open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb') as f:
        pkl.dump(valid_dataloader.state_dict(), f)


    iteration = train(
        neox_args=neox_args,
        timers=timers,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_data_iterator=train_data_iterator,
        valid_data_iterator=valid_data_iterator,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
    )

    if neox_args.do_valid:
@@ -240,7 +244,7 @@ def pretrain(neox_args):
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=valid_data_iterator,
            data_iterator=iter(valid_dataloader) if valid_dataloader is not None else None,
            model=model,
            iteration=iteration,
            verbose=False,
@@ -255,6 +259,8 @@ def pretrain(neox_args):
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        with open(neox_args.train_streaming_data_config['state_dict_path'], 'wb') as f:
            pkl.dump(train_dataloader.state_dict(), f)
        with open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb') as f:
            pkl.dump(valid_dataloader.state_dict(), f)

    if neox_args.do_test:
        # Run on test data.
@@ -263,7 +269,7 @@ def pretrain(neox_args):
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=test_data_iterator,
            data_iterator=iter(test_dataloader) if test_dataloader is not None else None,
            model=model,
            iteration=iteration,
            verbose=True,
@@ -817,11 +823,12 @@ def train(
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
    train_dataloader,
    valid_dataloader,
):
    """Train the model function."""

    train_data_iterator = iter(train_dataloader) if train_dataloader is not None else None
    valid_data_iterator = iter(valid_dataloader) if valid_dataloader is not None else None
    # Turn on training mode which enables dropout.
    model.train()
@@ -887,6 +894,8 @@ def train(
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
            with open(neox_args.train_streaming_data_config['state_dict_path'], 'wb') as f:
                pkl.dump(train_dataloader.state_dict(), f)
            with open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb') as f:
                pkl.dump(valid_dataloader.state_dict(), f)

        # Evaluation
        if (
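The paired pickle dumps recur at every site that saves a model checkpoint: after setup, after validation, and inside the training loop. A sketch that consolidates them and skips absent splits; save_dataloader_states is an illustrative name, not part of the commit:

import pickle as pkl

def save_dataloader_states(neox_args, train_dataloader, valid_dataloader):
    # Pickle each loader's state next to the model checkpoint,
    # skipping splits that were not built.
    for loader, config in (
        (train_dataloader, neox_args.train_streaming_data_config),
        (valid_dataloader, neox_args.valid_streaming_data_config),
    ):
        if loader is not None:
            with open(config['state_dict_path'], 'wb') as f:
                pkl.dump(loader.state_dict(), f)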
