Commit
Merge pull request #59 from AGI-Collective/cleaned_padding
Cleaned padding
kshitijkg authored Oct 26, 2023
2 parents 3dbc199 + 3803f5d commit 8a6afa1
Showing 13 changed files with 599 additions and 287 deletions.
60 changes: 42 additions & 18 deletions megatron/data/data_utils.py
@@ -32,6 +32,10 @@

from streaming import Stream, StreamingDataset
from omegaconf import OmegaConf as om
import pickle as pkl

import os

def make_data_loader(dataset, neox_args):
    """Build dataloader given an input dataset."""
    if dataset is None:
@@ -324,8 +328,9 @@ def build_streaming_train_valid_test_data_iterators(neox_args):
    def prepare_config(dataset_config):
        dataset_config['num_workers'] = neox_args.num_workers
        dataset_config['dataset']['max_seq_length'] = neox_args.seq_length
        dataset_config['dataset']['eos_token_id'] = neox_args.tokenizer.eod_id
        dataset_config['dataset']['remote'] = None  # TODO Allow remote datasets
        dataset_config['dataset']['position_pad_id'] = neox_args.position_pad_id
        dataset_config['dataset']['vision_pad_id'] = neox_args.vision_pad_id

    prepare_config(neox_args.train_streaming_data_config)
    prepare_config(neox_args.valid_streaming_data_config)
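For reference, a minimal sketch of the shape of a streaming data config that prepare_config fills in. The key nesting and the top-level state_dict_path follow the accesses visible in this diff; the concrete values and the use of OmegaConf.create are hypothetical, not taken from the repo.

# Hypothetical config skeleton; field names mirror the diff above.
from omegaconf import OmegaConf as om

train_streaming_data_config = om.create({
    'num_workers': 2,                        # overwritten from neox_args.num_workers
    'state_dict_path': '/tmp/train_state',   # read by the resume logic below
    'dataset': {
        'split': 'train',
        'max_seq_length': 2048,              # overwritten from neox_args.seq_length
        'remote': None,                      # TODO in the diff: allow remote datasets
        'position_pad_id': -1,               # overwritten from neox_args.position_pad_id
        'vision_pad_id': -1,                 # overwritten from neox_args.vision_pad_id
    },
})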
@@ -340,9 +345,7 @@ def prepare_config(dataset_config):
    tokenizer = neox_args.tokenizer

    train_dataloader = build_interleaved_dataloader(train_dataset_cfg, tokenizer, device_batch_size)
    train_dataset_cfg['dataset']['split'] = "validation"
    valid_dataloader = build_interleaved_dataloader(validation_dataset_cfg, tokenizer, device_batch_size)
    validation_dataset_cfg['dataset']['split'] = "test"
    test_dataloader = build_interleaved_dataloader(test_dataset_cfg, tokenizer, device_batch_size)

    # Flags to know if we need to do training/validation/testing.
@@ -368,25 +371,46 @@ def prepare_config(dataset_config):
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()


    # Build iterators.

    # Shift the start iterations.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None
    train_state_dict_path = neox_args.train_streaming_data_config['state_dict_path']
    if os.path.exists(train_state_dict_path):
        file_name = os.path.join(train_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl')

        if os.path.isfile(file_name):  # A checkpoint exists for this iteration
            with open(file_name, 'rb') as f:  # Close the handle after loading
                train_state_dict = pkl.load(f)
            print(train_state_dict)
            train_dataloader.load_state_dict(train_state_dict)
        else:
            print("No matching state dict found.")

    else:
        print_rank_0("setting training data start iteration to 0")

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None
    valid_state_dict_path = neox_args.valid_streaming_data_config['state_dict_path']
    if os.path.exists(valid_state_dict_path):
        file_name = os.path.join(valid_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl')

        if os.path.isfile(file_name):  # A checkpoint exists for this iteration
            with open(file_name, 'rb') as f:  # Close the handle after loading
                valid_state_dict = pkl.load(f)
            print(valid_state_dict)
            valid_dataloader.load_state_dict(valid_state_dict)
        else:
            print("No matching state dict found.")
    else:
        print_rank_0("setting validation data start iteration to 0")

    return train_data_iterator, valid_data_iterator, test_data_iterator
    return train_dataloader, valid_dataloader, test_dataloader
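The load path above implies a save-side counterpart that pickles the dataloader state to {iteration}_checkpoint.pkl at checkpoint time. A minimal sketch, assuming the dataloader exposes state_dict()/load_state_dict() as the streaming library's StreamingDataLoader does; the helper name is hypothetical and not part of this diff.

import os
import pickle as pkl

def save_dataloader_checkpoint(dataloader, state_dict_path, iteration):
    """Persist dataloader state so the load path above can restore it."""
    os.makedirs(state_dict_path, exist_ok=True)
    file_name = os.path.join(state_dict_path, f'{iteration}_checkpoint.pkl')
    state = dataloader.state_dict()  # sample position within the stream
    with open(file_name, 'wb') as f:
        pkl.dump(state, f)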

def build_train_valid_test_data_iterators(neox_args):
    """XXX"""
