@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 import paddle
 from paddle.distributed import fleet

 from paddlenlp.utils.log import logger
-
-_MAX_DATA_DIM = 64
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_copy_place,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)


 class DummyDataset(paddle.io.Dataset):
@@ -53,6 +56,7 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
+        eval=False,
     ):

         if dataset is None:
@@ -62,12 +66,15 @@ def __init__(
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)

         self._hcg = fleet.get_hybrid_communicate_group()
+        self.eval = eval

         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None

         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -78,10 +85,6 @@ def __init__(
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

-        # When needed other data types, we can modify dtype_list.
-        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
-        self._data_keys_list, self._data_keys_size = None, None
-
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,
@@ -127,7 +130,6 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")

         for group in parallel_groups:
-            # only first rank and last rank
             ranks = [group[0], group[-1]]
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
@@ -137,127 +139,70 @@ def _init_dataloader_comm_group(self):
     def __iter__(self):
         return self

-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
                     )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                fake_data = [None]
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
                     )
-
-        if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
-                    )
-
-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+                fake_data = [None]
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
+            )
         else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+            fake_data = [None]

-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
+        fake_data = fake_data[0]
+        if fake_data is None:
+            raise StopIteration

-    # Broadcast
-    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)

-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # for pp1 - pp_{n-1}, Paddle needs to receive an empty dict for pipeline parallel.
+        if data is None:
+            data = {}

-    return ret
+        return data
+
+    def __next__(self):
+        data = None
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+            except:
+                pass
+        data = self._broadcast_data(data)
+        return data
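
Note: the new `_broadcast_data` replaces the dtype-bucketed `broadcast_data_list` path with a two-step exchange: the source rank first shares a lightweight description of the batch (`nested_reduce_tensor` via `broadcast_object_list`), ranks without data allocate matching placeholders (`nested_empty_tensor`), and the real values are then broadcast (`nested_broadcast_tensor`). The sketch below is only an illustration of that pattern on a flat `dict[str, Tensor]` using plain `paddle.distributed` calls; `broadcast_batch` is a hypothetical helper, it assumes an already-initialized process group and int64 tensors, and it is not the PaddleNLP implementation.

```python
import paddle
import paddle.distributed as dist


def broadcast_batch(batch, src=0, group=None):
    """Broadcast a dict of tensors from `src` to every rank in `group`.

    `batch` is a dict[str, Tensor] on the src rank and None elsewhere.
    For brevity this sketch assumes every tensor is int64 (e.g. input_ids,
    labels); the real helpers carry the dtype inside the metadata instead.
    """
    rank = dist.get_rank()

    # 1) Share lightweight metadata (keys and shapes) as a picklable object,
    #    mirroring nested_reduce_tensor + broadcast_object_list.
    meta = [{k: v.shape for k, v in batch.items()}] if rank == src else [None]
    dist.broadcast_object_list(meta, src=src, group=group)
    meta = meta[0]

    # 2) Ranks without data allocate placeholders, mirroring nested_empty_tensor.
    if rank != src:
        batch = {k: paddle.empty(shape, dtype="int64") for k, shape in meta.items()}

    # 3) Broadcast the actual values tensor by tensor, mirroring nested_broadcast_tensor.
    for k in sorted(batch):
        dist.broadcast(batch[k], src=src, group=group)
    return batch
```

Because only metadata crosses the wire before the tensor broadcasts, ranks that never run the inner dataloader (mp_rank != 0, or non-first pipeline stages) can still return a full batch from `__next__`, and a `None` meta object doubles as the end-of-iterator signal that `_broadcast_data` turns into `StopIteration`.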