
Commit 84b4bf7

add first try

1 parent e7a6ac2

File tree: 1 file changed (+17 −16 lines)


paddlenlp/data/dist_dataloader.py

Lines changed: 17 additions & 16 deletions
@@ -71,8 +71,10 @@ def __init__(
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -128,11 +130,7 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")
 
         for group in parallel_groups:
-            if not self.eval:
-                # only first rank and last rank
-                ranks = [group[0], group[-1]]
-            else:
-                ranks = group
+            ranks = [group[0], group[-1]]
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
                 parallel_comm_group = comm_group
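
With the eval branch removed, _init_dataloader_comm_group always builds the data comm group from just the first and last rank of each pipeline group. Below is a minimal standalone sketch of that logic (an illustration, not the class method itself), assuming paddle.distributed has been initialized and topo is the fleet topology object used above:

    import paddle

    def init_dataloader_comm_group(topo):
        # One entry per pipeline in the "pipe" comm list; keep only its first and last rank.
        parallel_comm_group = None
        for group in topo.get_comm_list("pipe"):
            ranks = [group[0], group[-1]]
            comm_group = paddle.distributed.new_group(ranks=ranks)
            if paddle.distributed.get_rank() in ranks:
                parallel_comm_group = comm_group
        return parallel_comm_group
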
@@ -152,8 +150,8 @@ def _broadcast_data(self, data):
                         f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
                     )
                 fake_data = [None]
-        if self._pp_data_group is not None:
-            if process_rank == self._pp_data_group.ranks[0]:
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
                 fake_data = [nested_reduce_tensor(data)]
             else:
                 if data is not None:
@@ -167,31 +165,34 @@ def _broadcast_data(self, data):
                 src=self.mp_src_rank,
                 group=self.mp_group,
             )
-        if self._pp_data_group is not None:
+        if self._pp_group is not None:
             paddle.distributed.broadcast_object_list(
                 fake_data,
-                src=self._pp_data_group.ranks[0],
-                group=self._pp_data_group,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
             )
         else:
             fake_data = [None]
 
         fake_data = fake_data[0]
+        if fake_data is None:
+            raise StopIteration
 
+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
         if self.mp_group.nranks > 1:
             if process_rank != self.mp_src_rank:
                 data = nested_empty_tensor(fake_data)
-        if self._pp_data_group is not None:
-            if process_rank != self._pp_data_group.ranks[0]:
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
                 data = nested_empty_tensor(fake_data)
 
         if self.mp_group.nranks > 1 and self.pp_rank == 0:
             data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
-        if self._pp_data_group is not None:
-            data = nested_broadcast_tensor(data, src=self._pp_data_group.ranks[0], group=self._pp_data_group)
-
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # For pp ranks 1 .. pp_{n-1}, Paddle needs to receive an empty dict for pipeline parallel.
         if data is None:
-            raise StopIteration
+            data = {}
 
         return data
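
Taken together, the tail of _broadcast_data now turns an exhausted iterator into StopIteration before any tensor broadcast, and picks which pipeline group carries the actual tensors based on eval mode. A rough sketch of that control flow follows (a hypothetical free-standing helper, not the real method; nested_empty_tensor, nested_broadcast_tensor and the attributes _pp_group, _pp_data_group, mp_group, mp_src_rank, pp_rank and eval are assumed to behave as in dist_dataloader.py above):

    import paddle

    def broadcast_data_tail(self, data, fake_data):
        # fake_data is the metadata already broadcast over mp_group / _pp_group;
        # None on every rank means the underlying dataloader is exhausted.
        if fake_data is None:
            raise StopIteration

        # Tensors travel over the full pipe group in eval mode, otherwise over
        # the reduced first/last-rank data group.
        dst_pp_group = self._pp_group if self.eval else self._pp_data_group

        rank = paddle.distributed.get_rank()
        if self.mp_group.nranks > 1 and rank != self.mp_src_rank:
            data = nested_empty_tensor(fake_data)
        if dst_pp_group is not None and rank != dst_pp_group.ranks[0]:
            data = nested_empty_tensor(fake_data)

        if self.mp_group.nranks > 1 and self.pp_rank == 0:
            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
        if dst_pp_group is not None:
            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)

        # Intermediate pipeline stages receive no payload; Paddle expects an empty dict there.
        return data if data is not None else {}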

0 commit comments
