diff --git a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml index 05a6520110..fa321f1605 100644 --- a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml +++ b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 16 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. +shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/benchmark/transformer/reader.py b/PaddleNLP/benchmark/transformer/reader.py index 38fcda422f..9e3f86ee0e 100644 --- a/PaddleNLP/benchmark/transformer/reader.py +++ b/PaddleNLP/benchmark/transformer/reader.py @@ -43,6 +43,12 @@ def create_data_loader(args): mode=m, transform_func=transform_func) for m in ["train", "dev"] ] + if args.shuffle or args.shuffle_batch: + if args.shuffle_seed == "None" or args.shuffle_seed is None: + shuffle_seed = 0 + else: + shuffle_seed = args.shuffle_seed + def _max_token_fn(current_idx, current_batch_size, tokens_sofar, data_source): return max(tokens_sofar, @@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len): key=trg_key, buffer_size=buffer_size).sort( key=src_key, buffer_size=buffer_size) else: - sampler = sampler.shuffle() + if args.shuffle: + sampler = sampler.shuffle(seed=shuffle_seed) if args.sort_type == SortType.POOL: buffer_size = args.pool_size sampler = sampler.sort(key=src_key, buffer_size=buffer_size) @@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len): if m == "train": batch_sampler = batch_sampler.shard() + if args.shuffle_batch: + batch_sampler = batch_sampler.shuffle(seed=shuffle_seed) + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, diff --git 
a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml index 7ea9ebbe71..57070dc296 100644 --- a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml +++ b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 8 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. +shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml index d458f4c7eb..4cd3b1201f 100644 --- a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml +++ b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 8 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. 
+shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/examples/machine_translation/transformer/reader.py b/PaddleNLP/examples/machine_translation/transformer/reader.py index 38fcda422f..9e3f86ee0e 100644 --- a/PaddleNLP/examples/machine_translation/transformer/reader.py +++ b/PaddleNLP/examples/machine_translation/transformer/reader.py @@ -43,6 +43,12 @@ def create_data_loader(args): mode=m, transform_func=transform_func) for m in ["train", "dev"] ] + if args.shuffle or args.shuffle_batch: + if args.shuffle_seed == "None" or args.shuffle_seed is None: + shuffle_seed = 0 + else: + shuffle_seed = args.shuffle_seed + def _max_token_fn(current_idx, current_batch_size, tokens_sofar, data_source): return max(tokens_sofar, @@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len): key=trg_key, buffer_size=buffer_size).sort( key=src_key, buffer_size=buffer_size) else: - sampler = sampler.shuffle() + if args.shuffle: + sampler = sampler.shuffle(seed=shuffle_seed) if args.sort_type == SortType.POOL: buffer_size = args.pool_size sampler = sampler.sort(key=src_key, buffer_size=buffer_size) @@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len): if m == "train": batch_sampler = batch_sampler.shard() + if args.shuffle_batch: + batch_sampler = batch_sampler.shuffle(seed=shuffle_seed) + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler,