diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index b86acf33d..00274ce50 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -32,6 +32,10 @@
 from streaming import Stream, StreamingDataset
 from omegaconf import OmegaConf as om
 
+import pickle as pkl
+
+import os
+
 def make_data_loader(dataset, neox_args):
     """Build dataloader given an input dataset."""
     if dataset is None:
@@ -367,25 +371,39 @@ def prepare_config(dataset_config):
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
-
-
-    # Build iterators.
-    if train_dataloader is not None:
-        train_data_iterator = iter(train_dataloader)
-    else:
-        train_data_iterator = None
+    # Shift the start iterations.
+    if train_dataloader is not None:
+        train_state_dict_path = neox_args.train_streaming_data_config['state_dict_path']
+        if os.path.exists(train_state_dict_path):
+            train_state_dict = pkl.load(open(train_state_dict_path, 'rb'))
+            print(train_state_dict)
+            train_dataloader.load_state_dict(train_state_dict)
+            # Print all the key-value pairs of the state dict
+            for k, v in train_state_dict.items():
+                print_rank_0(f"Loaded {k} with value {v}")
+        else:
+            print_rank_0(
+                "setting training data start iteration to {}".format(
+                    0
+                )
+            )
+
 
     if valid_dataloader is not None:
-        valid_data_iterator = iter(valid_dataloader)
-    else:
-        valid_data_iterator = None
-
-    if test_dataloader is not None:
-        test_data_iterator = iter(test_dataloader)
-    else:
-        test_data_iterator = None
+        valid_state_dict_path = neox_args.valid_streaming_data_config['state_dict_path']
+        if os.path.exists(valid_state_dict_path):
+            valid_state_dict = pkl.load(open(valid_state_dict_path, 'rb'))
+            valid_dataloader.load_state_dict(valid_state_dict)
+            for k, v in valid_state_dict.items():
+                print_rank_0(f"Loaded {k} with value {v}")
+        else:
+            print_rank_0(
+                "setting validation data start iteration to {}".format(
+                    0
+                )
+            )
 
-    return train_data_iterator, valid_data_iterator, test_data_iterator
+    return train_dataloader, valid_dataloader, test_dataloader
 
 def build_train_valid_test_data_iterators(neox_args):
     """XXX"""
diff --git a/megatron/training.py b/megatron/training.py
index 9e39287ef..171265f2c 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -57,6 +57,7 @@
 )
 from megatron.model.gpt2_model import cross_entropy
 from eval_tasks import run_eval_harness
+import pickle as pkl
 
 
 def mup_weights_reinit(neox_args, model):
@@ -198,14 +199,14 @@ def pretrain(neox_args):
     # Data stuff.
     timers("train/valid/test data iterators").start()
     (
-        train_data_iterator,
-        valid_data_iterator,
-        test_data_iterator,
+        train_dataloader,
+        valid_dataloader,
+        test_dataloader,
     ) = build_streaming_train_valid_test_data_iterators(neox_args=neox_args)
     timers("train/valid/test data iterators").stop()
 
     if neox_args.use_mup and neox_args.coord_check:
-        mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)
+        mup_coord_check(neox_args, timers, lr_scheduler, iter(train_dataloader) if train_dataloader is not None else None)
 
     # Print setup timing.
print_rank_0("done with setups ...") @@ -223,6 +224,9 @@ def pretrain(neox_args): optimizer=optimizer, lr_scheduler=lr_scheduler, ) + pkl.dump(train_dataloader.state_dict(), open(neox_args.train_streaming_data_config['state_dict_path'], 'wb')) + pkl.dump(valid_dataloader.state_dict(), open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb')) + iteration = train( neox_args=neox_args, @@ -230,8 +234,8 @@ def pretrain(neox_args): model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, - train_data_iterator=train_data_iterator, - valid_data_iterator=valid_data_iterator, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, ) if neox_args.do_valid: @@ -240,7 +244,7 @@ def pretrain(neox_args): neox_args=neox_args, prefix=prefix, forward_step_func=forward_step, - data_iterator=valid_data_iterator, + data_iterator=iter(valid_dataloader) if valid_dataloader is not None else None, model=model, iteration=iteration, verbose=False, @@ -255,6 +259,8 @@ def pretrain(neox_args): optimizer=optimizer, lr_scheduler=lr_scheduler, ) + pkl.dump(train_dataloader.state_dict(), open(neox_args.train_streaming_data_config['state_dict_path'], 'wb')) + pkl.dump(valid_dataloader.state_dict(), open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb')) if neox_args.do_test: # Run on test data. @@ -263,7 +269,7 @@ def pretrain(neox_args): neox_args=neox_args, prefix=prefix, forward_step_func=forward_step, - data_iterator=test_data_iterator, + data_iterator=iter(test_dataloader) if test_dataloader is not None else None, model=model, iteration=iteration, verbose=True, @@ -817,11 +823,12 @@ def train( model, optimizer, lr_scheduler, - train_data_iterator, - valid_data_iterator, + train_dataloader, + valid_dataloader, ): """Train the model function.""" - + train_data_iterator = iter(train_dataloader) if train_dataloader else None + valid_data_iterator = iter(valid_dataloader) if valid_dataloader else None # Turn on training mode which enables dropout. model.train() @@ -887,6 +894,8 @@ def train( optimizer=optimizer, lr_scheduler=lr_scheduler, ) + pkl.dump(train_dataloader.state_dict(), open(neox_args.train_streaming_data_config['state_dict_path'], 'wb')) + pkl.dump(valid_dataloader.state_dict(), open(neox_args.valid_streaming_data_config['state_dict_path'], 'wb')) # Evaluation if (