Skip to content

Commit 992063c

Browse files
committed
fix async save hang
1 parent 2b62766 commit 992063c

File tree

1 file changed

13 additions and 2 deletions (+13, -2 lines changed)

1 file changed

13 additions and 2 deletions (+13, -2 lines changed)

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -189,6 +189,7 @@ def _file_save_async_or_sync(
189189
),
190190
)
191191
self._process_model_weight.start()
192+
process = self._process_model_weight
192193
elif state_dict_type == "master_weight":
193194
if self._shm_master_weight is None:
194195
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -215,6 +216,7 @@ def _file_save_async_or_sync(
215216
),
216217
)
217218
self._process_master_weight.start()
219+
process = self._process_master_weight
218220
elif state_dict_type == "optimizer_weight":
219221
if self._shm_optimizer_weight is None:
220222
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -239,11 +241,14 @@ def _file_save_async_or_sync(
239241
),
240242
)
241243
self._process_optimizer_weight.start()
244+
process = self._process_optimizer_weight
242245

243246
while True: # wait until no process is saving.
244247
flag_value = shared_save_flag[0]
245248
if flag_value == 0:
246249
break
250+
if not process.is_alive():
251+
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
247252
time.sleep(0.5)
248253
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
249254
# only save model weight or save master weight, we enter this loop.
@@ -278,7 +283,6 @@ def _save_file_async_in_process(
278283
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
279284
safe_save_file(state_dict, path, {"format": "np"})
280285
del state_dict
281-
os.makedirs(signal_path, exist_ok=True)
282286
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
283287
paddle.save(global_rank, saved_signal_path)
284288
with lock:
@@ -771,14 +775,20 @@ def unlink_shared_memory(self):
771775

772776
if self._shared_save_model_flag is not None:
773777
while self._shared_save_model_flag[0] > 0: # async process is saving
778+
if not self._process_model_weight.is_alive():
779+
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
774780
time.sleep(0.5)
775781
self._shared_save_model_flag[0] = -1
776782
if self._shared_save_master_weight_flag is not None:
777783
while self._shared_save_master_weight_flag[0] > 0:
784+
if not self._process_master_weight.is_alive():
785+
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
778786
time.sleep(0.5)
779787
self._shared_save_master_weight_flag[0] = -1
780788
if self._shared_save_optimizer_flag is not None:
781789
while self._shared_save_optimizer_flag[0] > 0:
790+
if not self._process_optimizer_weight.is_alive():
791+
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
782792
time.sleep(0.5)
783793
self._shared_save_optimizer_flag[0] = -1
784794

@@ -795,7 +805,8 @@ def unlink_shared_memory(self):
795805
self._shm_optimizer_weight.unlink()
796806
self._shm_optimizer_weight = None
797807

798-
dist.barrier()
808+
if paddle.distributed.get_world_size() > 1:
809+
dist.barrier()
799810

800811

801812
def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):

0 commit comments

Comments
 (0)