Skip to content

Commit

Permalink
Fix hung (#5121)
Browse files Browse the repository at this point in the history
* fix hung

* add shuffle batch

* update

* reader_seed to shuffle_seed

* seed for shuffle batch
  • Loading branch information
FrostML authored Dec 19, 2020
1 parent 047b8b6 commit 4d87afd
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 2 deletions.
6 changes: 6 additions & 0 deletions PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ pool_size: 200000
sort_type: "global"
batch_size: 4096
infer_batch_size: 16
shuffle_batch: True
# Data shuffle only works when sort_type is pool or none
shuffle: True
# shuffle_seed must be set when shuffle is True and using multi-cards to train.
# Otherwise, the number of batches cannot be guaranteed.
shuffle_seed: 128

# Hyparams for training:
# The number of epoches for training
Expand Down
12 changes: 11 additions & 1 deletion PaddleNLP/benchmark/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def create_data_loader(args):
mode=m, transform_func=transform_func) for m in ["train", "dev"]
]

if args.shuffle or args.shuffle_batch:
if args.shuffle_seed == "None" or args.shuffle_seed is None:
shuffle_seed = 0
else:
shuffle_seed = args.shuffle_seed

def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
data_source):
return max(tokens_sofar,
Expand All @@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len):
key=trg_key, buffer_size=buffer_size).sort(
key=src_key, buffer_size=buffer_size)
else:
sampler = sampler.shuffle()
if args.shuffle:
sampler = sampler.shuffle(seed=shuffle_seed)
if args.sort_type == SortType.POOL:
buffer_size = args.pool_size
sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
Expand All @@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len):
if m == "train":
batch_sampler = batch_sampler.shard()

if args.shuffle_batch:
batch_sampler.shuffle(seed=shuffle_seed)

data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ pool_size: 200000
sort_type: "global"
batch_size: 4096
infer_batch_size: 8
shuffle_batch: True
# Data shuffle only works when sort_type is pool or none
shuffle: True
# shuffle_seed must be set when shuffle is True and using multi-cards to train.
# Otherwise, the number of batches cannot be guaranteed.
shuffle_seed: 128

# Hyparams for training:
# The number of epoches for training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ pool_size: 200000
sort_type: "global"
batch_size: 4096
infer_batch_size: 8
shuffle_batch: True
# Data shuffle only works when sort_type is pool or none
shuffle: True
# shuffle_seed must be set when shuffle is True and using multi-cards to train.
# Otherwise, the number of batches cannot be guaranteed.
shuffle_seed: 128

# Hyparams for training:
# The number of epoches for training
Expand Down
12 changes: 11 additions & 1 deletion PaddleNLP/examples/machine_translation/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def create_data_loader(args):
mode=m, transform_func=transform_func) for m in ["train", "dev"]
]

if args.shuffle or args.shuffle_batch:
if args.shuffle_seed == "None" or args.shuffle_seed is None:
shuffle_seed = 0
else:
shuffle_seed = args.shuffle_seed

def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
data_source):
return max(tokens_sofar,
Expand All @@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len):
key=trg_key, buffer_size=buffer_size).sort(
key=src_key, buffer_size=buffer_size)
else:
sampler = sampler.shuffle()
if args.shuffle:
sampler = sampler.shuffle(seed=shuffle_seed)
if args.sort_type == SortType.POOL:
buffer_size = args.pool_size
sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
Expand All @@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len):
if m == "train":
batch_sampler = batch_sampler.shard()

if args.shuffle_batch:
batch_sampler.shuffle(seed=shuffle_seed)

data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
Expand Down

0 comments on commit 4d87afd

Please sign in to comment.