@@ -140,21 +140,25 @@ def __init__(self, args):
         self._process_master_weight = None
         self._process_optimizer_weight = None
         self._lock = None
-        self._shared_save_path = None
         self._shared_save_model_flag = None
         self._shared_save_master_weight_flag = None
         self._shared_save_optimizer_flag = None

         if "async_save" in self.args.unified_checkpoint_config:
             self._lock = multiprocessing.Lock()
             self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_model_flag = multiprocessing.Array("i", 1)
             self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
             self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

-    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
         if is_sync:
             for k in list(state_dict.keys()):
                 if isinstance(state_dict[k], paddle.Tensor):
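The IPC primitive this hunk builds on is a pair of fixed-size `multiprocessing.Array` buffers per artifact: a `"c"` (char) array that carries a destination path as NUL-padded UTF-8 bytes, and an `"i"` (int) flag that tells the saver process when work is pending. A minimal, self-contained sketch of the string-passing half (not PaddleNLP code; path is illustrative):

```python
import multiprocessing

shared_path = multiprocessing.Array("c", 100000)  # zero-initialized byte buffer

def write_path(shared_array, value: str):
    encoded = value.encode("utf-8")
    shared_array[:] = b"\x00" * len(shared_array)  # clear stale bytes first
    shared_array[: len(encoded)] = encoded

def read_path(shared_array) -> str:
    # strip the NUL padding, mirroring the rstrip("\x00") in _save_file_async_in_process
    return shared_array[:].decode("utf-8").rstrip("\x00")

write_path(shared_path, "/tmp/ckpt/model-weight-00001.safetensors")
print(read_path(shared_path))  # /tmp/ckpt/model-weight-00001.safetensors
```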
@@ -169,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_model
                 shared_save_flag = self._shared_save_model_flag
                 shared_save_path = self._shared_save_model_path
+                shared_save_signal_path = self._shared_save_model_signal_path
                 if self._process_model_weight is None:
                     self._process_model_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -177,6 +182,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_model_weight.name,
                             self._shared_save_model_flag,
                             self._shared_save_model_path,
+                            self._shared_save_model_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
@@ -191,6 +197,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_master_weight
                 shared_save_flag = self._shared_save_master_weight_flag
                 shared_save_path = self._shared_save_master_weight_path
+                shared_save_signal_path = self._shared_save_master_weight_signal_path
                 if self._process_master_weight is None:
                     self._process_master_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -199,6 +206,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_master_weight.name,
                             self._shared_save_master_weight_flag,
                             self._shared_save_master_weight_path,
+                            self._shared_save_master_weight_signal_path,
                             self._lock,
                             "model_weight"
                             if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -215,6 +223,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_optim
                 shared_save_flag = self._shared_save_optimizer_flag
                 shared_save_path = self._shared_save_optimizer_path
+                shared_save_signal_path = self._shared_save_optimizer_signal_path
                 if self._process_optimizer_weight is None:
                     self._process_optimizer_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -223,6 +232,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_optimizer_weight.name,
                             self._shared_save_optimizer_flag,
                             self._shared_save_optimizer_path,
+                            self._shared_save_optimizer_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
@@ -238,6 +248,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
             # only save model weight or save master weight, we enter this loop.
             self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
             _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
             with self._lock:
                 shared_save_flag[0] = 1
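This is the producer side of a one-slot handshake: wait until the flag is 0 (saver idle), publish the new destination and signal paths, copy the tensors into shared memory, then raise the flag under the lock. A runnable miniature of the same handshake (simplified names and payload; the real code also copies the state dict into a `shared_memory` block before raising the flag):

```python
import multiprocessing
import time

def saver(flag, path_buf, lock):
    # long-lived consumer: poll the flag, "save", then mark itself idle
    while True:
        if flag[0] == -1:  # shutdown request
            break
        if flag[0] == 1:   # work pending
            path = path_buf[:].decode("utf-8").rstrip("\x00")
            print(f"saving to {path}")
            with lock:
                flag[0] = 0
        time.sleep(0.1)

if __name__ == "__main__":
    lock = multiprocessing.Lock()
    flag = multiprocessing.Array("i", 1)
    path_buf = multiprocessing.Array("c", 1000)
    proc = multiprocessing.Process(target=saver, args=(flag, path_buf, lock))
    proc.start()

    payload = b"/tmp/ckpt/optimizer.safetensors"
    path_buf[: len(payload)] = payload
    with lock:
        flag[0] = 1            # hand off to the saver
    while flag[0] != 0:        # wait until it finishes
        time.sleep(0.1)
    flag[0] = -1               # ask the saver to exit
    proc.join()
```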
@@ -248,6 +259,7 @@ def _save_file_async_in_process(
         shm_name,
         shared_save_flag,
         shared_save_path,
+        shared_save_signal_path,
         lock,
         state_dict_type,
         global_rank,
@@ -261,11 +273,13 @@ def _save_file_async_in_process(
                 continue
             if flag_value == 1:  # need to save
                 path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                 logger.info(f"Start to async save {path}")
                 state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                 safe_save_file(state_dict, path, {"format": "np"})
                 del state_dict
-                saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
+                os.makedirs(signal_path, exist_ok=True)
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                 paddle.save(global_rank, saved_signal_path)
                 with lock:
                     shared_save_flag[0] = 0
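The behavioral change in this hunk: the `.{state_dict_type}.done.{global_rank}` marker is no longer dropped next to the checkpoint shard (`os.path.dirname(path)`) but into the dedicated `signal_path` directory. A hypothetical coordinator-side checker (not part of this diff) that polls those markers could look like:

```python
import os

def all_ranks_done(signal_dir: str, state_dict_type: str, world_size: int) -> bool:
    # True once every rank has dropped its ".<type>.done.<rank>" marker file
    return all(
        os.path.exists(os.path.join(signal_dir, f".{state_dict_type}.done.{rank}"))
        for rank in range(world_size)
    )

# e.g. all_ranks_done("outputs/async_save_signals", "optimizer_weight", 8)
```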
@@ -280,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
         encoded_value = new_value.encode("utf-8")
         shared_array[: len(encoded_value)] = encoded_value

-    def save_unified_checkpoint(self, model, optimizer, output_dir):
+    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
         """save unified checkpoint

         Args:
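In `_reset_and_update`, the method name and hunk header indicate a clearing step just above the two lines shown. That zeroing matters: writing a shorter path over a longer one would otherwise leave the old path's tail bytes behind, and the reader's `rstrip("\x00")` cannot remove them. A small demonstration of the failure mode:

```python
import multiprocessing

buf = multiprocessing.Array("c", 32)
old, new = b"/ckpt/step-1000", b"/ckpt/step-2"

buf[: len(old)] = old
buf[: len(new)] = new                          # overwrite WITHOUT resetting
print(buf[:].decode("utf-8").rstrip("\x00"))   # "/ckpt/step-2000" -- stale tail!

buf[:] = b"\x00" * len(buf)                    # reset first ...
buf[: len(new)] = new                          # ... then update
print(buf[:].decode("utf-8").rstrip("\x00"))   # "/ckpt/step-2"
```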
@@ -317,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)  # only for async save

         # save model weights
         if not skip_save_model_weight:
@@ -329,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
             self._file_save_async_or_sync(
                 state_dict,
                 path=os.path.join(save_directory, shard_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="model_weight",
             )
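With these two hunks, `save_unified_checkpoint` accepts an optional `signal_dir` and threads it down to every `_file_save_async_or_sync` call. A hedged caller-side sketch; the trainer-side call site is not part of this diff, and the handler object and paths below are illustrative only:

```python
# checkpoint_handler and the two paths are assumptions for illustration;
# only the keyword signature comes from this diff.
checkpoint_handler.save_unified_checkpoint(
    model,
    optimizer,
    output_dir="outputs/checkpoint-1000",      # checkpoint shards land here
    signal_dir="outputs/async_save_signals",   # ".done" markers land here
)
```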
@@ -397,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
         if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

-    def save_non_merge_optimizer(self, model, optimizer, output_dir):
+    def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
         paddle.device.cuda.empty_cache()
         optim_state_dict = nested_copy(optimizer.state_dict())
         master_weights = None
@@ -456,12 +473,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(output_dir, optimizer_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         self._file_save_async_or_sync(
             master_weights,
             path=os.path.join(output_dir, master_weights_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="master_weight",
         )
@@ -511,22 +530,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

         return returned_optim_state_dict

-    def save_unified_optimizer(self, model, optimizer, output_dir):
+    def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
         """save unified optimizer

         Args:
             model (PretrainedModel): model used to get key mapping.
             optimizer (Optimizer): optimizer to save
             output_dir (str): Save directory.
+            signal_dir (str): Asynchronous saving signal directory.

         """

         if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
-            self.save_non_merge_optimizer(model, optimizer, output_dir)
+            self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
             return

         if paddle.distributed.get_world_size() <= 1:
-            self.save_single_card_optimizer(model, optimizer, output_dir)
+            self.save_single_card_optimizer(model, optimizer, output_dir)  # no need to save signal
             return

         # Split into naive optimizer params and master weights.
@@ -542,20 +562,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)

         is_sync_save = True
         if "async_save" in self.args.unified_checkpoint_config:
             is_sync_save = False
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(save_directory, shard_optim_file),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         if master_weight_state_dict is not None:
             self._file_save_async_or_sync(
                 master_weight_state_dict,
                 path=os.path.join(save_directory, shard_master_weight_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="master_weight",
             )
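Since model, master-weight, and optimizer saves now all share the same flag-and-signal machinery, a caller can tell when the background savers have drained. A hypothetical helper, not part of this diff, that reaches into the private flag arrays introduced above:

```python
import time

def wait_for_async_saves(handler, timeout: float = 3600.0):
    # poll the "i"-typed flag Arrays; 0 means the saver process is idle
    flags = [
        handler._shared_save_model_flag,
        handler._shared_save_master_weight_flag,
        handler._shared_save_optimizer_flag,
    ]
    deadline = time.monotonic() + timeout
    while any(f is not None and f[0] != 0 for f in flags):
        if time.monotonic() > deadline:
            raise TimeoutError("async checkpoint save did not finish in time")
        time.sleep(0.5)
```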