@@ -136,21 +136,25 @@ def __init__(self, args):
         self._process_master_weight = None
         self._process_optimizer_weight = None
         self._lock = None
-        self._shared_save_path = None
         self._shared_save_model_flag = None
         self._shared_save_master_weight_flag = None
         self._shared_save_optimizer_flag = None

         if "async_save" in self.args.unified_checkpoint_config:
             self._lock = multiprocessing.Lock()
             self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_model_flag = multiprocessing.Array("i", 1)
             self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
             self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

-    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
         if is_sync:
             for k in list(state_dict.keys()):
                 if isinstance(state_dict[k], paddle.Tensor):
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_model
                 shared_save_flag = self._shared_save_model_flag
                 shared_save_path = self._shared_save_model_path
+                shared_save_signal_path = self._shared_save_model_signal_path
                 if self._process_model_weight is None:
                     self._process_model_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_model_weight.name,
                             self._shared_save_model_flag,
                             self._shared_save_model_path,
+                            self._shared_save_model_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_model_weight.start()
+                process = self._process_model_weight
             elif state_dict_type == "master_weight":
                 if self._shm_master_weight is None:
                     self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_master_weight
                 shared_save_flag = self._shared_save_master_weight_flag
                 shared_save_path = self._shared_save_master_weight_path
+                shared_save_signal_path = self._shared_save_master_weight_signal_path
                 if self._process_master_weight is None:
                     self._process_master_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_master_weight.name,
                             self._shared_save_master_weight_flag,
                             self._shared_save_master_weight_path,
+                            self._shared_save_master_weight_signal_path,
                             self._lock,
                             "model_weight"
                             if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         ),
                     )
                     self._process_master_weight.start()
+                process = self._process_master_weight
             elif state_dict_type == "optimizer_weight":
                 if self._shm_optimizer_weight is None:
                     self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_optim
                 shared_save_flag = self._shared_save_optimizer_flag
                 shared_save_path = self._shared_save_optimizer_path
+                shared_save_signal_path = self._shared_save_optimizer_signal_path
                 if self._process_optimizer_weight is None:
                     self._process_optimizer_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_optimizer_weight.name,
                             self._shared_save_optimizer_flag,
                             self._shared_save_optimizer_path,
+                            self._shared_save_optimizer_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_optimizer_weight.start()
+                process = self._process_optimizer_weight

             while True:  # wait until no process is saving.
                 flag_value = shared_save_flag[0]
                 if flag_value == 0:
                     break
+                if not process.is_alive():
+                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
                 time.sleep(0.5)
                 logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
             # only save model weight or save master weight, we enter this loop.
             self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
             _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
             with self._lock:
                 shared_save_flag[0] = 1
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
         shm_name,
         shared_save_flag,
         shared_save_path,
+        shared_save_signal_path,
         lock,
         state_dict_type,
         global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
                 continue
             if flag_value == 1:  # need to save
                 path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                 logger.info(f"Start to async save {path}")
                 state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                 safe_save_file(state_dict, path, {"format": "np"})
                 del state_dict
-                saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                 paddle.save(global_rank, saved_signal_path)
                 with lock:
                     shared_save_flag[0] = 0
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
         encoded_value = new_value.encode("utf-8")
         shared_array[: len(encoded_value)] = encoded_value

-    def save_unified_checkpoint(self, model, optimizer, output_dir):
+    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
         """save unified checkpoint

         Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)  # only for async save

         # save model weights
         if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
             self._file_save_async_or_sync(
                 state_dict,
                 path=os.path.join(save_directory, shard_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="model_weight",
             )
@@ -393,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
         if self.args.dataset_rank == 0:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

-    def save_non_merge_optimizer(self, model, optimizer, output_dir):
+    def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
         paddle.device.cuda.empty_cache()
         optim_state_dict = nested_copy(optimizer.state_dict())
         master_weights = None
@@ -432,12 +453,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(output_dir, optimizer_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         self._file_save_async_or_sync(
             master_weights,
             path=os.path.join(output_dir, master_weights_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="master_weight",
         )
@@ -484,22 +507,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

         return returned_optim_state_dict

-    def save_unified_optimizer(self, model, optimizer, output_dir):
+    def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
         """save unified optimizer

         Args:
             model (PretrainedModel): model used to get key mapping.
             optimizer (Optimizer): optimizer to save
             output_dir (str): Save directory.
+            signal_dir (str): Asynchronous saving signal directory.

         """

         if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
-            self.save_non_merge_optimizer(model, optimizer, output_dir)
+            self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
             return

         if paddle.distributed.get_world_size() <= 1:
-            self.save_single_card_optimizer(model, optimizer, output_dir)
+            self.save_single_card_optimizer(model, optimizer, output_dir)  # no need to save signal
             return

         # Split into naive optimizer params and master weights.
@@ -515,20 +539,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)

         is_sync_save = True
         if "async_save" in self.args.unified_checkpoint_config:
             is_sync_save = False
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(save_directory, shard_optim_file),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         if master_weight_state_dict is not None:
             self._file_save_async_or_sync(
                 master_weight_state_dict,
                 path=os.path.join(save_directory, shard_master_weight_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="master_weight",
             )
@@ -716,14 +744,20 @@ def unlink_shared_memory(self):

         if self._shared_save_model_flag is not None:
             while self._shared_save_model_flag[0] > 0:  # async process is saving
+                if not self._process_model_weight.is_alive():
+                    raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_model_flag[0] = -1
         if self._shared_save_master_weight_flag is not None:
             while self._shared_save_master_weight_flag[0] > 0:
+                if not self._process_master_weight.is_alive():
+                    raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_master_weight_flag[0] = -1
         if self._shared_save_optimizer_flag is not None:
             while self._shared_save_optimizer_flag[0] > 0:
+                if not self._process_optimizer_weight.is_alive():
+                    raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_optimizer_flag[0] = -1

@@ -740,7 +774,8 @@ def unlink_shared_memory(self):
             self._shm_optimizer_weight.unlink()
             self._shm_optimizer_weight = None

-        dist.barrier()
+        if paddle.distributed.get_world_size() > 1:
+            dist.barrier()


 def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
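
The mechanism this diff extends is a parent/worker handshake over shared arrays: the trainer stages a target path (and now a signal directory) into multiprocessing.Array buffers, flips an integer flag to 1, and the long-lived save worker polls the flag, writes the file, drops a ".{state_dict_type}.done.{global_rank}" marker into the signal directory, and resets the flag; the new is_alive() checks keep the trainer from spinning forever if the worker dies. Below is a minimal standalone sketch of that handshake, assuming simplified hypothetical names (worker, write_str, path_buf, signal_buf) and with plain file writes standing in for safe_save_file and paddle.save; it is an illustration, not the PaddleNLP implementation.

# Sketch of the flag/path handshake, with hypothetical names and plain-file stand-ins.
import multiprocessing
import os
import time


def worker(flag, path_buf, signal_buf, rank):
    while True:
        if flag[0] == -1:  # parent asked the worker to exit
            break
        if flag[0] == 1:  # parent staged a save request
            path = path_buf[:].decode("utf-8").rstrip("\x00")
            signal_dir = signal_buf[:].decode("utf-8").rstrip("\x00")
            with open(path, "w") as f:  # stand-in for safe_save_file(state_dict, path, ...)
                f.write("checkpoint payload")
            done = os.path.join(signal_dir, f".model_weight.done.{rank}")
            open(done, "w").close()  # stand-in for paddle.save(global_rank, saved_signal_path)
            flag[0] = 0  # signal the parent that the save finished
        time.sleep(0.1)


def write_str(arr, value):
    # mirrors _reset_and_update: clear the buffer, then copy the new value in
    for i in range(len(arr)):
        arr[i] = b"\x00"
    encoded = value.encode("utf-8")
    arr[: len(encoded)] = encoded


if __name__ == "__main__":
    flag = multiprocessing.Array("i", 1)
    path_buf = multiprocessing.Array("c", 1000)
    signal_buf = multiprocessing.Array("c", 1000)
    os.makedirs("ckpt", exist_ok=True)
    os.makedirs("signals", exist_ok=True)

    p = multiprocessing.Process(target=worker, args=(flag, path_buf, signal_buf, 0))
    p.start()

    while flag[0] != 0:  # wait until no save is in flight
        if not p.is_alive():  # mirrors the liveness check added in the diff
            raise RuntimeError("save worker died unexpectedly")
        time.sleep(0.1)
    write_str(path_buf, "ckpt/model-00001-of-00001.safetensors")
    write_str(signal_buf, "signals")
    flag[0] = 1  # hand the request to the worker

    while flag[0] != 0:  # wait for completion, then shut the worker down
        time.sleep(0.1)
    flag[0] = -1
    p.join()
    print(os.listdir("signals"))  # expect ['.model_weight.done.0']

Separating the signal directory from the checkpoint directory (the signal_path/signal_dir plumbing above) lets rank-completion markers be collected somewhere other than next to the shard files, which is why every save call now forwards signal_dir alongside path.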