@@ -71,8 +71,10 @@ def __init__(
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None

         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -128,11 +130,7 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")

         for group in parallel_groups:
-            if not self.eval:
-                # only first rank and last rank
-                ranks = [group[0], group[-1]]
-            else:
-                ranks = group
+            ranks = [group[0], group[-1]]
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
                 parallel_comm_group = comm_group
@@ -152,8 +150,8 @@ def _broadcast_data(self, data):
                         f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
                     )
                 fake_data = [None]
-        if self._pp_data_group is not None:
-            if process_rank == self._pp_data_group.ranks[0]:
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
                 fake_data = [nested_reduce_tensor(data)]
         else:
             if data is not None:
@@ -167,31 +165,34 @@ def _broadcast_data(self, data):
                 src=self.mp_src_rank,
                 group=self.mp_group,
             )
-        if self._pp_data_group is not None:
+        if self._pp_group is not None:
             paddle.distributed.broadcast_object_list(
                 fake_data,
-                src=self._pp_data_group.ranks[0],
-                group=self._pp_data_group,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
             )
         else:
             fake_data = [None]

         fake_data = fake_data[0]
+        if fake_data is None:
+            raise StopIteration

+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
         if self.mp_group.nranks > 1:
             if process_rank != self.mp_src_rank:
                 data = nested_empty_tensor(fake_data)
-        if self._pp_data_group is not None:
-            if process_rank != self._pp_data_group.ranks[0]:
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
                 data = nested_empty_tensor(fake_data)

         if self.mp_group.nranks > 1 and self.pp_rank == 0:
             data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
-        if self._pp_data_group is not None:
-            data = nested_broadcast_tensor(data, src=self._pp_data_group.ranks[0], group=self._pp_data_group)
-
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # for pp1 - pp_{n-1}, Paddle needs to receive an empty dict for pipeline parallel.
         if data is None:
-            raise StopIteration
+            data = {}

         return data

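Not part of the commit, but a minimal standalone sketch of the rank-selection logic the diff relies on. The helper names below are hypothetical and the group handles are replaced by plain rank lists; it only mirrors `ranks = [group[0], group[-1]]` in `_init_dataloader_comm_group` and `dst_pp_group = self._pp_group if self.eval else self._pp_data_group` in `_broadcast_data`.

# Hypothetical illustration only, plain Python, no Paddle APIs.

def dataloader_ranks(pipe_group_ranks):
    # Data comm group: only the first and last pipeline stages load data.
    return [pipe_group_ranks[0], pipe_group_ranks[-1]]

def broadcast_ranks(pipe_group_ranks, is_eval):
    # Eval broadcasts to every pipeline stage; training only to the data-loading stages.
    return list(pipe_group_ranks) if is_eval else dataloader_ranks(pipe_group_ranks)

if __name__ == "__main__":
    pipe_group = [0, 1, 2, 3]  # ranks of one pipeline-parallel group
    print(broadcast_ranks(pipe_group, is_eval=False))  # [0, 3]
    print(broadcast_ranks(pipe_group, is_eval=True))   # [0, 1, 2, 3]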