@@ -189,6 +189,7 @@ def _file_save_async_or_sync(
189
189
),
190
190
)
191
191
self ._process_model_weight .start ()
192
+ process = self ._process_model_weight
192
193
elif state_dict_type == "master_weight" :
193
194
if self ._shm_master_weight is None :
194
195
self ._meta_dict_master_weight , buffer_size = create_meta_dict (state_dict )
@@ -215,6 +216,7 @@ def _file_save_async_or_sync(
215
216
),
216
217
)
217
218
self ._process_master_weight .start ()
219
+ process = self ._process_master_weight
218
220
elif state_dict_type == "optimizer_weight" :
219
221
if self ._shm_optimizer_weight is None :
220
222
self ._meta_dict_optim , buffer_size = create_meta_dict (state_dict )
@@ -239,11 +241,14 @@ def _file_save_async_or_sync(
239
241
),
240
242
)
241
243
self ._process_optimizer_weight .start ()
244
+ process = self ._process_optimizer_weight
242
245
243
246
while True : # wait until no process is saving.
244
247
flag_value = shared_save_flag [0 ]
245
248
if flag_value == 0 :
246
249
break
250
+ if not process .is_alive ():
251
+ raise RuntimeError (f"The process that saves { state_dict_type } has been killed unexpectedly." )
247
252
time .sleep (0.5 )
248
253
logger .info (f"Wait for the previous save process to finish saving { state_dict_type } " )
249
254
# only save model weight or save master weight, we enter this loop.
@@ -278,7 +283,6 @@ def _save_file_async_in_process(
278
283
state_dict = _read_state_dict_from_shm (meta_dict , shm ) # numpy array
279
284
safe_save_file (state_dict , path , {"format" : "np" })
280
285
del state_dict
281
- os .makedirs (signal_path , exist_ok = True )
282
286
saved_signal_path = os .path .join (signal_path , f".{ state_dict_type } .done.{ global_rank } " )
283
287
paddle .save (global_rank , saved_signal_path )
284
288
with lock :
@@ -771,14 +775,20 @@ def unlink_shared_memory(self):
771
775
772
776
if self ._shared_save_model_flag is not None :
773
777
while self ._shared_save_model_flag [0 ] > 0 : # async process is saving
778
+ if not self ._process_model_weight .is_alive ():
779
+ raise RuntimeError ("The process that saves model_weight has been killed unexpectedly." )
774
780
time .sleep (0.5 )
775
781
self ._shared_save_model_flag [0 ] = - 1
776
782
if self ._shared_save_master_weight_flag is not None :
777
783
while self ._shared_save_master_weight_flag [0 ] > 0 :
784
+ if not self ._process_master_weight .is_alive ():
785
+ raise RuntimeError ("The process that saves master_weight has been killed unexpectedly." )
778
786
time .sleep (0.5 )
779
787
self ._shared_save_master_weight_flag [0 ] = - 1
780
788
if self ._shared_save_optimizer_flag is not None :
781
789
while self ._shared_save_optimizer_flag [0 ] > 0 :
790
+ if not self ._process_optimizer_weight .is_alive ():
791
+ raise RuntimeError ("The process that saves optimizer_weight has been killed unexpectedly." )
782
792
time .sleep (0.5 )
783
793
self ._shared_save_optimizer_flag [0 ] = - 1
784
794
@@ -795,7 +805,8 @@ def unlink_shared_memory(self):
795
805
self ._shm_optimizer_weight .unlink ()
796
806
self ._shm_optimizer_weight = None
797
807
798
- dist .barrier ()
808
+ if paddle .distributed .get_world_size () > 1 :
809
+ dist .barrier ()
799
810
800
811
801
812
def load_unified_checkpoint_locally (args , model , resume_from_checkpoint : str , safe_serialization = False ):
0 commit comments