Skip to content

Commit 3cd145c

Browse files
authored
[Unified Checkpoint] update async save logic (#9274) (#9275)
* update async save signal
* fix async save hang
1 parent 7084196 commit 3cd145c

File tree

4 files changed

+103
-34
lines changed

4 files changed

+103
-34
lines changed

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,25 @@ def __init__(self, args):
136136
self._process_master_weight = None
137137
self._process_optimizer_weight = None
138138
self._lock = None
139-
self._shared_save_path = None
140139
self._shared_save_model_flag = None
141140
self._shared_save_master_weight_flag = None
142141
self._shared_save_optimizer_flag = None
143142

144143
if "async_save" in self.args.unified_checkpoint_config:
145144
self._lock = multiprocessing.Lock()
146145
self._shared_save_model_path = multiprocessing.Array("c", 100000)
146+
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
147147
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
148+
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
148149
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
150+
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
149151
self._shared_save_model_flag = multiprocessing.Array("i", 1)
150152
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
151153
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
152154

153-
def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
155+
def _file_save_async_or_sync(
156+
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
157+
):
154158
if is_sync:
155159
for k in list(state_dict.keys()):
156160
if isinstance(state_dict[k], paddle.Tensor):
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
165169
meta_dict = self._meta_dict_model
166170
shared_save_flag = self._shared_save_model_flag
167171
shared_save_path = self._shared_save_model_path
172+
shared_save_signal_path = self._shared_save_model_signal_path
168173
if self._process_model_weight is None:
169174
self._process_model_weight = multiprocessing.Process(
170175
target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
173178
self._shm_model_weight.name,
174179
self._shared_save_model_flag,
175180
self._shared_save_model_path,
181+
self._shared_save_model_signal_path,
176182
self._lock,
177183
state_dict_type,
178184
self.global_rank,
179185
),
180186
)
181187
self._process_model_weight.start()
188+
process = self._process_model_weight
182189
elif state_dict_type == "master_weight":
183190
if self._shm_master_weight is None:
184191
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
187194
meta_dict = self._meta_dict_master_weight
188195
shared_save_flag = self._shared_save_master_weight_flag
189196
shared_save_path = self._shared_save_master_weight_path
197+
shared_save_signal_path = self._shared_save_master_weight_signal_path
190198
if self._process_master_weight is None:
191199
self._process_master_weight = multiprocessing.Process(
192200
target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
195203
self._shm_master_weight.name,
196204
self._shared_save_master_weight_flag,
197205
self._shared_save_master_weight_path,
206+
self._shared_save_master_weight_signal_path,
198207
self._lock,
199208
"model_weight"
200209
if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
203212
),
204213
)
205214
self._process_master_weight.start()
215+
process = self._process_master_weight
206216
elif state_dict_type == "optimizer_weight":
207217
if self._shm_optimizer_weight is None:
208218
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
211221
meta_dict = self._meta_dict_optim
212222
shared_save_flag = self._shared_save_optimizer_flag
213223
shared_save_path = self._shared_save_optimizer_path
224+
shared_save_signal_path = self._shared_save_optimizer_signal_path
214225
if self._process_optimizer_weight is None:
215226
self._process_optimizer_weight = multiprocessing.Process(
216227
target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
219230
self._shm_optimizer_weight.name,
220231
self._shared_save_optimizer_flag,
221232
self._shared_save_optimizer_path,
233+
self._shared_save_optimizer_signal_path,
222234
self._lock,
223235
state_dict_type,
224236
self.global_rank,
225237
),
226238
)
227239
self._process_optimizer_weight.start()
240+
process = self._process_optimizer_weight
228241

229242
while True: # wait until no process is saving.
230243
flag_value = shared_save_flag[0]
231244
if flag_value == 0:
232245
break
246+
if not process.is_alive():
247+
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
233248
time.sleep(0.5)
234249
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
235250
# only save model weight or save master weight, we enter this loop.
236251
self._reset_and_update(shared_save_path, path)
252+
self._reset_and_update(shared_save_signal_path, signal_path)
237253
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
238254
with self._lock:
239255
shared_save_flag[0] = 1
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
244260
shm_name,
245261
shared_save_flag,
246262
shared_save_path,
263+
shared_save_signal_path,
247264
lock,
248265
state_dict_type,
249266
global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
257274
continue
258275
if flag_value == 1: # need to save
259276
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
277+
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
260278
logger.info(f"Start to async save {path}")
261279
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
262280
safe_save_file(state_dict, path, {"format": "np"})
263281
del state_dict
264-
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
282+
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
265283
paddle.save(global_rank, saved_signal_path)
266284
with lock:
267285
shared_save_flag[0] = 0
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
276294
encoded_value = new_value.encode("utf-8")
277295
shared_array[: len(encoded_value)] = encoded_value
278296

279-
def save_unified_checkpoint(self, model, optimizer, output_dir):
297+
def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
280298
"""save unified checkpoint
281299
282300
Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
313331

314332
save_directory = output_dir
315333
os.makedirs(save_directory, exist_ok=True)
334+
if signal_dir is not None:
335+
os.makedirs(signal_dir, exist_ok=True) # only for async save
316336

317337
# save model weights
318338
if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
325345
self._file_save_async_or_sync(
326346
state_dict,
327347
path=os.path.join(save_directory, shard_file),
348+
signal_path=signal_dir,
328349
is_sync=is_sync_save,
329350
state_dict_type="model_weight",
330351
)
@@ -393,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
393414
if self.args.dataset_rank == 0:
394415
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)
395416

396-
def save_non_merge_optimizer(self, model, optimizer, output_dir):
417+
def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
397418
paddle.device.cuda.empty_cache()
398419
optim_state_dict = nested_copy(optimizer.state_dict())
399420
master_weights = None
@@ -432,12 +453,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
432453
self._file_save_async_or_sync(
433454
optim_state_dict,
434455
path=os.path.join(output_dir, optimizer_name),
456+
signal_path=signal_dir,
435457
is_sync=is_sync_save,
436458
state_dict_type="optimizer_weight",
437459
)
438460
self._file_save_async_or_sync(
439461
master_weights,
440462
path=os.path.join(output_dir, master_weights_name),
463+
signal_path=signal_dir,
441464
is_sync=is_sync_save,
442465
state_dict_type="master_weight",
443466
)
@@ -484,22 +507,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
484507

485508
return returned_optim_state_dict
486509

487-
def save_unified_optimizer(self, model, optimizer, output_dir):
510+
def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
488511
"""save unified optimizer
489512
490513
Args:
491514
model (PretrainedModel): model used to get key mapping.
492515
optimizer (Optimizer): optimizer to save
493516
output_dir (str): Save directory.
517+
signal_dir (str): Asynchronous saving signal directory.
494518
495519
"""
496520

497521
if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
498-
self.save_non_merge_optimizer(model, optimizer, output_dir)
522+
self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
499523
return
500524

501525
if paddle.distributed.get_world_size() <= 1:
502-
self.save_single_card_optimizer(model, optimizer, output_dir)
526+
self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal
503527
return
504528

505529
# Split into naive optimizer params and master weights.
@@ -515,20 +539,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
515539

516540
save_directory = output_dir
517541
os.makedirs(save_directory, exist_ok=True)
542+
if signal_dir is not None:
543+
os.makedirs(signal_dir, exist_ok=True)
518544

519545
is_sync_save = True
520546
if "async_save" in self.args.unified_checkpoint_config:
521547
is_sync_save = False
522548
self._file_save_async_or_sync(
523549
optim_state_dict,
524550
path=os.path.join(save_directory, shard_optim_file),
551+
signal_path=signal_dir,
525552
is_sync=is_sync_save,
526553
state_dict_type="optimizer_weight",
527554
)
528555
if master_weight_state_dict is not None:
529556
self._file_save_async_or_sync(
530557
master_weight_state_dict,
531558
path=os.path.join(save_directory, shard_master_weight_file),
559+
signal_path=signal_dir,
532560
is_sync=is_sync_save,
533561
state_dict_type="master_weight",
534562
)
@@ -716,14 +744,20 @@ def unlink_shared_memory(self):
716744

717745
if self._shared_save_model_flag is not None:
718746
while self._shared_save_model_flag[0] > 0: # async process is saving
747+
if not self._process_model_weight.is_alive():
748+
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
719749
time.sleep(0.5)
720750
self._shared_save_model_flag[0] = -1
721751
if self._shared_save_master_weight_flag is not None:
722752
while self._shared_save_master_weight_flag[0] > 0:
753+
if not self._process_master_weight.is_alive():
754+
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
723755
time.sleep(0.5)
724756
self._shared_save_master_weight_flag[0] = -1
725757
if self._shared_save_optimizer_flag is not None:
726758
while self._shared_save_optimizer_flag[0] > 0:
759+
if not self._process_optimizer_weight.is_alive():
760+
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
727761
time.sleep(0.5)
728762
self._shared_save_optimizer_flag[0] = -1
729763

@@ -740,7 +774,8 @@ def unlink_shared_memory(self):
740774
self._shm_optimizer_weight.unlink()
741775
self._shm_optimizer_weight = None
742776

743-
dist.barrier()
777+
if paddle.distributed.get_world_size() > 1:
778+
dist.barrier()
744779

745780

746781
def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):

0 commit comments

Comments (0)