Commit 2b62766

update async save signal

1 parent 0e96b0f commit 2b62766
File tree: 4 files changed, +91 −33 lines

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 32 additions & 8 deletions
@@ -140,21 +140,25 @@ def __init__(self, args):
         self._process_master_weight = None
         self._process_optimizer_weight = None
         self._lock = None
-        self._shared_save_path = None
         self._shared_save_model_flag = None
         self._shared_save_master_weight_flag = None
         self._shared_save_optimizer_flag = None

         if "async_save" in self.args.unified_checkpoint_config:
             self._lock = multiprocessing.Lock()
             self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_model_flag = multiprocessing.Array("i", 1)
             self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
             self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

-    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
         if is_sync:
             for k in list(state_dict.keys()):
                 if isinstance(state_dict[k], paddle.Tensor):
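The three new `*_signal_path` arrays mirror the existing `*_path` arrays: fixed-size shared byte buffers through which the trainer process hands a path to its saver subprocess. A minimal sketch of this layout, with illustrative names rather than PaddleNLP's own:

    import multiprocessing

    # A "c"-typed Array is a fixed-size shared byte buffer; an "i"-typed
    # Array of length 1 acts as a one-slot request flag.
    lock = multiprocessing.Lock()
    shared_save_path = multiprocessing.Array("c", 100000)         # weight file path
    shared_save_signal_path = multiprocessing.Array("c", 100000)  # new: directory for .done markers
    shared_save_flag = multiprocessing.Array("i", 1)              # 0 = idle, 1 = save requested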
@@ -169,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
             meta_dict = self._meta_dict_model
             shared_save_flag = self._shared_save_model_flag
             shared_save_path = self._shared_save_model_path
+            shared_save_signal_path = self._shared_save_model_signal_path
             if self._process_model_weight is None:
                 self._process_model_weight = multiprocessing.Process(
                     target=self._save_file_async_in_process,
@@ -177,6 +182,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         self._shm_model_weight.name,
                         self._shared_save_model_flag,
                         self._shared_save_model_path,
+                        self._shared_save_model_signal_path,
                         self._lock,
                         state_dict_type,
                         self.global_rank,
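Each saver subprocess receives the shared objects positionally; the hunk above slots the signal-path array in right after the save-path array. A hedged sketch of how one such process might be wired up, using a placeholder worker (the real target is `_save_file_async_in_process`, and its full argument list is longer than what this hunk shows):

    import multiprocessing

    def save_worker(shm_name, flag, path_arr, signal_path_arr, lock, sd_type, rank):
        # Placeholder with the same positional shape as the args visible above.
        pass

    if __name__ == "__main__":
        process = multiprocessing.Process(
            target=save_worker,
            args=(
                "shm_model_weight",                  # shared-memory block name
                multiprocessing.Array("i", 1),       # save-request flag
                multiprocessing.Array("c", 100000),  # weight file path
                multiprocessing.Array("c", 100000),  # new: signal directory path
                multiprocessing.Lock(),
                "model_weight",
                0,                                   # global rank
            ),
        )
        process.start()
        process.join()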
@@ -191,6 +197,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
             meta_dict = self._meta_dict_master_weight
             shared_save_flag = self._shared_save_master_weight_flag
             shared_save_path = self._shared_save_master_weight_path
+            shared_save_signal_path = self._shared_save_master_weight_signal_path
             if self._process_master_weight is None:
                 self._process_master_weight = multiprocessing.Process(
                     target=self._save_file_async_in_process,
@@ -199,6 +206,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         self._shm_master_weight.name,
                         self._shared_save_master_weight_flag,
                         self._shared_save_master_weight_path,
+                        self._shared_save_master_weight_signal_path,
                         self._lock,
                         "model_weight"
                         if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -215,6 +223,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
             meta_dict = self._meta_dict_optim
             shared_save_flag = self._shared_save_optimizer_flag
             shared_save_path = self._shared_save_optimizer_path
+            shared_save_signal_path = self._shared_save_optimizer_signal_path
             if self._process_optimizer_weight is None:
                 self._process_optimizer_weight = multiprocessing.Process(
                     target=self._save_file_async_in_process,
@@ -223,6 +232,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         self._shm_optimizer_weight.name,
                         self._shared_save_optimizer_flag,
                         self._shared_save_optimizer_path,
+                        self._shared_save_optimizer_signal_path,
                         self._lock,
                         state_dict_type,
                         self.global_rank,
@@ -238,6 +248,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
             # only save model weight or save master weight, we enter this loop.
             self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
             _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
             with self._lock:
                 shared_save_flag[0] = 1
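On the producer side, the commit stages the signal directory alongside the weight path before raising the flag. A self-contained sketch of that handshake, outside the class; the clear-then-write logic mirrors what `_reset_and_update` does:

    import time

    def stage_save_request(shared_flag, shared_path, shared_signal_path, lock,
                           path: str, signal_path: str) -> None:
        # Wait for the saver process to go idle, stage both paths in the
        # shared buffers, then raise the flag under the lock.
        while shared_flag[0] != 0:  # previous async save still in flight
            time.sleep(0.1)
        for array, value in ((shared_path, path), (shared_signal_path, signal_path)):
            array[:] = b"\x00" * len(array)  # drop stale bytes from a longer previous path
            encoded = value.encode("utf-8")
            array[: len(encoded)] = encoded
        with lock:
            shared_flag[0] = 1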
@@ -248,6 +259,7 @@ def _save_file_async_in_process(
         shm_name,
         shared_save_flag,
         shared_save_path,
+        shared_save_signal_path,
         lock,
         state_dict_type,
         global_rank,
@@ -261,11 +273,13 @@ def _save_file_async_in_process(
                 continue
             if flag_value == 1:  # need to save
                 path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                 logger.info(f"Start to async save {path}")
                 state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                 safe_save_file(state_dict, path, {"format": "np"})
                 del state_dict
-                saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
+                os.makedirs(signal_path, exist_ok=True)
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                 paddle.save(global_rank, saved_signal_path)
                 with lock:
                     shared_save_flag[0] = 0
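The behavioral change lands in this consumer loop: the `.done.{rank}` marker used to be written next to the weights via `os.path.dirname(path)`; it now goes into the dedicated `signal_path`, created on demand. A simplified, hedged sketch of the loop, where `save_fn` stands in for the shared-memory read plus `safe_save_file`, and the `-1` shutdown value is an assumption not shown in this hunk:

    import os
    import time

    def saver_loop(shared_flag, shared_path, shared_signal_path, lock,
                   state_dict_type, global_rank, save_fn):
        while True:
            flag_value = shared_flag[0]
            if flag_value == 0:  # idle: nothing requested yet
                time.sleep(0.1)
                continue
            if flag_value == 1:  # save requested by the trainer process
                path = shared_path[:].decode("utf-8").rstrip("\x00")
                signal_path = shared_signal_path[:].decode("utf-8").rstrip("\x00")
                save_fn(path)
                # The change in this commit: the .done marker lands in the
                # dedicated signal directory instead of next to the weights.
                os.makedirs(signal_path, exist_ok=True)
                done = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                open(done, "w").close()  # stand-in for paddle.save(global_rank, done)
                with lock:
                    shared_flag[0] = 0  # hand the buffers back to the trainer
            if flag_value == -1:  # assumed shutdown signal
                break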
@@ -280,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
         encoded_value = new_value.encode("utf-8")
         shared_array[: len(encoded_value)] = encoded_value

-    def save_unified_checkpoint(self, model, optimizer, output_dir):
+    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
         """save unified checkpoint

         Args:
@@ -317,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)  # only for async save

         # save model weights
         if not skip_save_model_weight:
@@ -329,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
             self._file_save_async_or_sync(
                 state_dict,
                 path=os.path.join(save_directory, shard_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="model_weight",
             )
@@ -397,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
         if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

-    def save_non_merge_optimizer(self, model, optimizer, output_dir):
+    def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
         paddle.device.cuda.empty_cache()
         optim_state_dict = nested_copy(optimizer.state_dict())
         master_weights = None
@@ -456,12 +473,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(output_dir, optimizer_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         self._file_save_async_or_sync(
             master_weights,
             path=os.path.join(output_dir, master_weights_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="master_weight",
         )
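With every save call now threading `signal_dir` through, a coordinator can watch a single directory for completion markers instead of scanning the checkpoint tree. An illustrative consumer, not part of this diff:

    import os
    import time

    def wait_for_done_markers(signal_dir, state_dict_type, world_size, timeout=600.0):
        # Poll for one .done marker per rank to confirm an async save has landed.
        expected = {f".{state_dict_type}.done.{rank}" for rank in range(world_size)}
        deadline = time.time() + timeout
        while time.time() < deadline:
            if expected.issubset(os.listdir(signal_dir)):
                return True
            time.sleep(1.0)
        return False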
@@ -511,22 +530,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

         return returned_optim_state_dict

-    def save_unified_optimizer(self, model, optimizer, output_dir):
+    def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
         """save unified optimizer

         Args:
             model (PretrainedModel): model used to get key mapping.
             optimizer (Optimizer): optimizer to save
             output_dir (str): Save directory.
+            signal_dir (str): Asynchronous saving signal directory.

         """

         if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
-            self.save_non_merge_optimizer(model, optimizer, output_dir)
+            self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
             return

         if paddle.distributed.get_world_size() <= 1:
-            self.save_single_card_optimizer(model, optimizer, output_dir)
+            self.save_single_card_optimizer(model, optimizer, output_dir)  # no need to save signal
             return

         # Split into naive optimizer params and master weights.
@@ -542,20 +562,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)

         is_sync_save = True
         if "async_save" in self.args.unified_checkpoint_config:
             is_sync_save = False
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(save_directory, shard_optim_file),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         if master_weight_state_dict is not None:
             self._file_save_async_or_sync(
                 master_weight_state_dict,
                 path=os.path.join(save_directory, shard_master_weight_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="master_weight",
             )
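Taken together, callers now pass a signal directory next to the output directory. A hypothetical call site, where `checkpoint_handler`, `model`, `optimizer` and the directory names are assumed for illustration:

    output_dir = "checkpoints/step-1000"
    signal_dir = "checkpoints/step-1000-signals"
    checkpoint_handler.save_unified_checkpoint(model, optimizer, output_dir, signal_dir=signal_dir)
    checkpoint_handler.save_unified_optimizer(model, optimizer, output_dir, signal_dir)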
