Commit 4d87afd

Fix hung (#5121)
* fix hung
* add shuffle batch
* update
* reader_seed to shuffle_seed
* seed for shuffle batch
1 parent 047b8b6 commit 4d87afd
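
The hang named in the commit title matches the comment added to the configs below: when each card shuffles with its own seed and batches are packed under a token budget, the cards can end up with different numbers of batches, so the card that exhausts its data first leaves the others blocked on the next collective op. A standalone sketch of that effect (plain Python with a simplified token-budget packer, not the repo's _max_token_fn/_key logic; the numbers are illustrative):

import random

def num_batches(sample_lengths, seed, max_tokens=4096):
    """Count the batches one card would form after shuffling with `seed`,
    packing greedily under a padded-token budget (illustrative only)."""
    order = sample_lengths[:]
    random.Random(seed).shuffle(order)
    batches, cur_max, cur_cnt = 0, 0, 0
    for n in order:
        # close the batch when padding to the longest sample would bust the budget
        if cur_cnt and max(cur_max, n) * (cur_cnt + 1) > max_tokens:
            batches += 1
            cur_max, cur_cnt = 0, 0
        cur_max, cur_cnt = max(cur_max, n), cur_cnt + 1
    return batches + (1 if cur_cnt else 0)

rng = random.Random(0)
lengths = [rng.randint(5, 256) for _ in range(20000)]

# Per-card seeds: the counts usually differ, so one card finishes its epoch
# early and the rest wait forever on an all-reduce that never arrives.
print(num_batches(lengths, seed=1), num_batches(lengths, seed=2))
# One shared shuffle_seed: every card forms the same number of batches.
print(num_batches(lengths, seed=128), num_batches(lengths, seed=128))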

File tree

5 files changed (+40 lines, -2 lines)

PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
Lines changed: 6 additions & 0 deletions

@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 16
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 
 # Hyparams for training:
 # The number of epoches for training
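
For reference, a minimal sketch of consuming these new keys (assuming PyYAML and a plain namespace; the benchmark's actual config loader may differ), mirroring the rule in reader.py that an unset or "None" shuffle_seed falls back to 0 so every card still shuffles with one common seed:

import yaml
from types import SimpleNamespace

# Hypothetical loader sketch; the repo's real config handling may differ.
with open("transformer.big.yaml") as f:
    args = SimpleNamespace(**yaml.safe_load(f))

if args.shuffle or args.shuffle_batch:
    # Same defaulting rule the reader applies: "None"/None becomes seed 0.
    shuffle_seed = 0 if args.shuffle_seed in ("None", None) else args.shuffle_seed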

PaddleNLP/benchmark/transformer/reader.py
Lines changed: 11 additions & 1 deletion

@@ -43,6 +43,12 @@ def create_data_loader(args):
             mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
 
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len):
         if m == "train":
            batch_sampler = batch_sampler.shard()
 
+        if args.shuffle_batch:
+            batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
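
The reader's flow after this change is: optional seeded sample-level shuffle, pool sort (when sort_type is pool), token-budget batching, sharding across cards for training, then an optional batch-level shuffle with the same seed. A standalone sketch of that ordering on plain Python lists (not paddlenlp.data.SamplerHelper; pool and batch sizes are toy values):

import random

def toy_pipeline(lengths, seed, rank, num_cards,
                 shuffle=True, shuffle_batch=True, pool_size=8, batch_size=4):
    idx = list(range(len(lengths)))
    if shuffle:
        random.Random(seed).shuffle(idx)        # sample-level shuffle, shared seed
    pooled = []
    for i in range(0, len(idx), pool_size):     # sort by length inside each pool
        pooled += sorted(idx[i:i + pool_size], key=lambda j: lengths[j])
    batches = [pooled[i:i + batch_size]         # fixed-size toy batching
               for i in range(0, len(pooled), batch_size)]
    batches = batches[rank::num_cards]          # shard whole batches per card
    if shuffle_batch:
        random.Random(seed).shuffle(batches)    # batch-level shuffle, same seed
    return batches

rng = random.Random(0)
lengths = [rng.randint(1, 50) for _ in range(64)]
card0 = toy_pipeline(lengths, seed=128, rank=0, num_cards=2)
card1 = toy_pipeline(lengths, seed=128, rank=1, num_cards=2)
assert len(card0) == len(card1)  # a shared seed keeps the cards in lockstep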

PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml
Lines changed: 6 additions & 0 deletions

@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 
 # Hyparams for training:
 # The number of epoches for training

PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml
Lines changed: 6 additions & 0 deletions

@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 
 # Hyparams for training:
 # The number of epoches for training

PaddleNLP/examples/machine_translation/transformer/reader.py
Lines changed: 11 additions & 1 deletion

@@ -43,6 +43,12 @@ def create_data_loader(args):
             mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
 
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def _key(size_so_far, minibatch_len):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def _key(size_so_far, minibatch_len):
         if m == "train":
            batch_sampler = batch_sampler.shard()
 
+        if args.shuffle_batch:
+            batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
