Commit e7a6ac2

Fix pipeline
1 parent ac3f1aa commit e7a6ac2

File tree: 2 files changed, 70 additions and 30 deletions

paddlenlp/data/dist_dataloader.py (11 additions, 5 deletions)

@@ -23,8 +23,6 @@
     nested_reduce_tensor,
 )
 
-_MAX_DATA_DIM = 64
-
 
 class DummyDataset(paddle.io.Dataset):
     """
@@ -58,6 +56,7 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
+        eval=False,
     ):
 
         if dataset is None:
@@ -67,6 +66,7 @@ def __init__(
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)
 
         self._hcg = fleet.get_hybrid_communicate_group()
+        self.eval = eval
 
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
@@ -128,8 +128,11 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")
 
         for group in parallel_groups:
-            # only first rank and last rank
-            ranks = [group[0], group[-1]]
+            if not self.eval:
+                # only first rank and last rank
+                ranks = [group[0], group[-1]]
+            else:
+                ranks = group
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
                 parallel_comm_group = comm_group
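The behavioral core of the hunks above is the rank selection in _init_dataloader_comm_group: during training only the first and last pipeline stages join the data comm group, while the new eval flag makes every rank of the pipe group join it. A minimal, standalone sketch of that selection logic follows; the four-rank pipe group is hypothetical, and in the real loader the groups come from the fleet topology and paddle.distributed.new_group.

def select_data_comm_ranks(pipe_group, eval=False):
    """Pick which pipeline ranks join the data comm group."""
    if not eval:
        # Training: only the first and last stages consume input data.
        return [pipe_group[0], pipe_group[-1]]
    # Evaluation: every stage joins, so all ranks can receive the batch.
    return list(pipe_group)


if __name__ == "__main__":
    group = [0, 1, 2, 3]  # hypothetical 4-stage pipeline group
    print(select_data_comm_ranks(group))             # [0, 3]
    print(select_data_comm_ranks(group, eval=True))  # [0, 1, 2, 3]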
@@ -170,6 +173,9 @@ def _broadcast_data(self, data):
                     src=self._pp_data_group.ranks[0],
                     group=self._pp_data_group,
                 )
+            else:
+                fake_data = [None]
+
             fake_data = fake_data[0]
 
         if self.mp_group.nranks > 1:
@@ -178,7 +184,6 @@ def _broadcast_data(self, data):
         if self._pp_data_group is not None:
             if process_rank != self._pp_data_group.ranks[0]:
                 data = nested_empty_tensor(fake_data)
-                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
 
         if self.mp_group.nranks > 1 and self.pp_rank == 0:
             data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
@@ -195,6 +200,7 @@ def __next__(self):
         if self._need_data:
            try:
                 data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
             except:
                 pass
         data = self._broadcast_data(data)
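The _broadcast_data and __next__ hunks above shift where device placement happens: nested_copy_place is no longer applied to the empty receive buffers in _broadcast_data but to the freshly fetched batch in __next__, and ranks that skip the object broadcast now fall back to fake_data = [None]. Below is a toy sketch of the fetch-then-copy ordering, assuming dict-of-tensor batches; copy_to_device and next_batch are illustrative stand-ins for nested_copy_place and DistDataLoader.__next__, not PaddleNLP API.

import paddle


def copy_to_device(batch):
    # Stand-in for nested_copy_place: move every tensor in a (possibly
    # nested) dict batch to the currently expected device.
    place = paddle.framework._current_expected_place()
    if isinstance(batch, dict):
        return {k: copy_to_device(v) for k, v in batch.items()}
    if isinstance(batch, paddle.Tensor):
        return paddle.to_tensor(batch, place=place)
    return batch


def next_batch(dataloader_iter, need_data=True):
    data = None
    if need_data:
        try:
            data = next(dataloader_iter)
            # Device copy now happens right after fetching on the data rank,
            # instead of on the receiving ranks inside _broadcast_data.
            data = copy_to_device(data)
        except StopIteration:
            pass
    return data  # the real loader then calls self._broadcast_data(data)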

paddlenlp/trainer/trainer.py (59 additions, 25 deletions)

@@ -1447,24 +1447,41 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
                     process_index=self.args.dataset_rank,
                 )
 
-            return _DataLoader(
-                eval_dataset,
-                batch_size=self.args.per_device_eval_batch_size,
-                collate_fn=self.data_collator,
-                num_workers=self.args.dataloader_num_workers,
-            )
+            if self.args.distributed_dataloader:
+                return _DataLoader(
+                    eval_dataset,
+                    batch_size=self.args.per_device_eval_batch_size,
+                    collate_fn=self.data_collator,
+                    num_workers=self.args.dataloader_num_workers,
+                    eval=True,
+                )
+            else:
+                return _DataLoader(
+                    eval_dataset,
+                    batch_size=self.args.per_device_eval_batch_size,
+                    collate_fn=self.data_collator,
+                    num_workers=self.args.dataloader_num_workers,
+                )
 
         eval_sampler = self._get_eval_sampler(eval_dataset)
 
         if self.args.distributed_dataloader:
             logger.info("Eval using DistDataLoader.")
 
-        return _DataLoader(
-            eval_dataset,
-            batch_sampler=eval_sampler,
-            collate_fn=self.data_collator,
-            num_workers=self.args.dataloader_num_workers,
-        )
+            return _DataLoader(
+                eval_dataset,
+                batch_sampler=eval_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                eval=True,
+            )
+        else:
+            return _DataLoader(
+                eval_dataset,
+                batch_sampler=eval_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+            )
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
@@ -1497,25 +1514,42 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                     process_index=self.args.dataset_rank,
                 )
 
-            return _DataLoader(
-                test_dataset,
-                batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                num_workers=self.args.dataloader_num_workers,
-            )
+            if self.args.distributed_dataloader:
+                return _DataLoader(
+                    test_dataset,
+                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
+                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
+                    num_workers=self.args.dataloader_num_workers,
+                    eval=True,
+                )
+            else:
+                return _DataLoader(
+                    test_dataset,
+                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
+                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
+                    num_workers=self.args.dataloader_num_workers,
+                )
 
         test_sampler = self._get_eval_sampler(test_dataset)
 
         if self.args.distributed_dataloader:
             logger.info("Test using DistDataLoader.")
 
-        # We use the same batch_size as for eval.
-        return _DataLoader(
-            test_dataset,
-            batch_sampler=test_sampler,
-            collate_fn=self.data_collator,
-            drop_last=self.args.dataloader_drop_last,
-        )
+            # We use the same batch_size as for eval.
+            return _DataLoader(
+                test_dataset,
+                batch_sampler=test_sampler,
+                collate_fn=self.data_collator,
+                drop_last=self.args.dataloader_drop_last,
+                eval=True,
+            )
+        else:
+            return _DataLoader(
+                test_dataset,
+                batch_sampler=test_sampler,
+                collate_fn=self.data_collator,
+                drop_last=self.args.dataloader_drop_last,
+            )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
