
Commit bcacc6a

sneaxiy and ForFishes authored
[Cherry pick] Sharding reshard function enhancement (#8544)
* fix bug of sharding format (#8483)
* Optimize the speed of set_state_dict (#8532)
* fix sharding reshard save (#8535)
* Fix ignore_data_skip bug when timer is enabled (#8536)
* Save parameter shape and dtype when using sharding reshard (#8543)
* save parameter shape and dtype
* refactor
* format pre-commit

---------

Co-authored-by: ShenLiang <2282912238@qq.com>
1 parent 162d8d3 commit bcacc6a

File tree

2 files changed (+35, -14 lines)


paddlenlp/trainer/trainer.py

Lines changed: 23 additions & 8 deletions
@@ -569,7 +569,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                     base_weight_name=weight_name,
                     model_wrapped=self.model_wrapped,
                 )
-                self.model.set_state_dict(state_dict)
+                old_state_dict = self.model.state_dict()
+                new_state_dict = {}
+                for k, v in state_dict.items():
+                    if k not in old_state_dict or id(v) != id(old_state_dict[k]):
+                        new_state_dict[k] = v
+                self.model.set_state_dict(new_state_dict)
             else:
                 if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

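Note on the hunk above: before calling set_state_dict, the loaded checkpoint is diffed against the model's live parameters by object identity, so only tensors that are genuinely new get copied in (the speed-up from #8532), presumably because the sharded loader can hand back some of the model's own tensors unchanged. A minimal sketch of the same identity filter on a toy Paddle layer; the loaded dict below is only a stand-in for whatever the checkpoint loader returns:

import paddle

model = paddle.nn.Linear(4, 4)

# Stand-in for a loaded checkpoint: the bias entry is the model's own tensor
# (nothing new to copy), while the weight entry is a freshly loaded copy.
loaded = dict(model.state_dict())
loaded["weight"] = loaded["weight"].clone()

old_state_dict = model.state_dict()
new_state_dict = {
    k: v
    for k, v in loaded.items()
    if k not in old_state_dict or id(v) != id(old_state_dict[k])
}
print(sorted(new_state_dict))  # ['weight'] is the only entry that still needs copying
model.set_state_dict(new_state_dict)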
@@ -891,7 +896,8 @@ def _inner_training_loop(
 
             npu_accelerate_plugin(self.optimizer)
 
-        self.timers and self.timers("read-data").start()
+        if self.args.ignore_data_skip:
+            self.timers and self.timers("read-data").start()
 
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
@@ -907,7 +913,9 @@ def _inner_training_loop(
                     inputs = split_inputs_sequence_dim(inputs)
                 if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim_load_balance(inputs)
-                self.timers and self.timers("read-data").stop()
+                if self.args.ignore_data_skip:
+                    self.timers and self.timers("read-data").stop()
+
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)

@@ -1098,7 +1106,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
-            self.timers and self.timers("read-data").start()
+
+            if self.args.ignore_data_skip:
+                self.timers and self.timers("read-data").start()
 
             if step < 0:
                 logger.warning(
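Note on the three read-data timer hunks above: every start()/stop() of the "read-data" timer is now gated on args.ignore_data_skip, so the timer only runs when the loop is not fast-forwarding over already-consumed batches (the fix from #8536), presumably to keep the start/stop calls paired on the skip path. A self-contained toy of the "self.timers and self.timers(...)" short-circuit idiom combined with that guard; the Timers class below is a stand-in, not the Trainer's actual timer implementation:

import time

class Timers:
    # Toy timer registry: the trainer's `timers` handle may be None when
    # profiling is off, which is what the `timers and timers(...)` idiom
    # short-circuits on.
    def __init__(self):
        self._start, self.elapsed = {}, {}
    def __call__(self, name):
        return _Timer(self, name)

class _Timer:
    def __init__(self, timers, name):
        self.timers, self.name = timers, name
    def start(self):
        self.timers._start[self.name] = time.perf_counter()
    def stop(self):
        begun = self.timers._start.pop(self.name)
        self.timers.elapsed[self.name] = (
            self.timers.elapsed.get(self.name, 0.0) + time.perf_counter() - begun
        )

timers = Timers()        # or None when the timer feature is disabled
ignore_data_skip = True  # the new guard: only time reads when nothing is skipped

if ignore_data_skip:
    timers and timers("read-data").start()
for step in range(3):
    if ignore_data_skip:
        timers and timers("read-data").stop()   # batch has been read
    # ... forward / backward / optimizer step ...
    if ignore_data_skip:
        timers and timers("read-data").start()  # time the next read
if ignore_data_skip:
    timers and timers("read-data").stop()
print(timers.elapsed)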
@@ -2462,10 +2472,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 if state_dict is None:
                     state_dict = self.model.state_dict()
 
-                self._save_ckpt_func(
-                    state_dict,
-                    os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
-                )
+                if self.args.should_save_sharding_stage1_model:
+                    state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config(
+                        unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict
+                    )
+                    variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix())
+                else:
+                    variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)
+
+                self._save_ckpt_func(state_dict, os.path.join(output_dir, variant))
             else:
                 if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                     config_to_save = None
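Note on the hunk above: when should_save_sharding_stage1_model is set, the already-built state dict is routed through sharding_io.manipulate_state_dict_and_config (which now accepts it via the new state_dict argument, see the sharding_io.py hunks below) so that each rank persists only the parameters it owns, and the file is written under the sharded name suffix instead of the plain weight suffix. A rough sketch of that per-rank filtering idea; keep_owned_params and the param2rank map below are hypothetical stand-ins, not the real filter_sharded_params:

import paddle

def keep_owned_params(state_dict, param2rank, sharding_rank):
    # Toy version of the stage-1 filtering idea: each sharding rank saves only
    # the parameters it owns, so the per-rank checkpoint files do not overlap.
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if param2rank.get(name, 0) == sharding_rank
    }

model = paddle.nn.Linear(4, 4)
param2rank = {"weight": 0, "bias": 1}  # hypothetical ownership map
rank0_shard = keep_owned_params(model.state_dict(), param2rank, sharding_rank=0)
print(sorted(rank0_shard))  # ['weight']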

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 12 additions & 6 deletions
@@ -321,12 +321,13 @@ def reshard_sharding(node_model_state):
         node_model_state = reshard_pp(node_model_state)
         return reshard_sharding(node_model_state)
 
-    def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False):
+    def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
         weight_name_suffix = self.args.sharded_name_suffix()
 
-        state_dict = model_to_save.state_dict()
-        if self.args.should_save_sharding_stage1_model:
-            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
+        if state_dict is None:
+            state_dict = model_to_save.state_dict()
+            if self.args.should_save_sharding_stage1_model:
+                state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
 
         config_to_save = None
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
@@ -384,7 +385,7 @@ def save_distributed_model_meta(self, dir):
 
         path = os.path.join(dir, MODEL_META_NAME)
         with open(path, "w") as f:
-            json.dump(model_meta, f, indent=4)
+            json.dump(model_meta, f)
 
     def _get_distributed_strategy(self):
         pp_degree = 1
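Note on the hunk above: the distributed model meta is now dumped as compact JSON. Its mappings grow with the number of parameters, so dropping indent=4 mainly saves file size and serialization time for a file that is read back by code rather than by people. A quick stdlib-only comparison with a synthetic payload standing in for the real MODEL_META_NAME contents:

import json

# Synthetic stand-in for the model meta payload: a large name mapping.
model_meta = {
    "structure_name_mapping": {f"layers.{i}.weight": f"param_{i}" for i in range(10000)}
}

compact = json.dumps(model_meta)
pretty = json.dumps(model_meta, indent=4)
print(len(compact), len(pretty))  # the indented form is noticeably larger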
@@ -544,13 +545,18 @@ def _gather_sharding_metas(self):
         pp_overlap = unwrap_optimizer(self.optimizer, DygraphShardingOptimizerV2).pp_overlap
 
         model = self.model
-        structure_name_mapping = {k: v.name for (k, v) in model.state_dict().items()}
+        structure_name_mapping = {}
+        param_meta = {}
+        for k, v in model.state_dict().items():
+            structure_name_mapping[k] = v.name
+            param_meta[k] = (v.shape, int(v.dtype))
 
         sharding_metas = {}
         sharding_meta = {}
 
         sharding_meta["param2rank"] = param2rank
         sharding_meta["structure_name_mapping"] = structure_name_mapping
+        sharding_meta["param_meta"] = param_meta
         sharding_meta["sharding_strategy"] = sharding_strategy
         sharding_meta["enable_overlap"] = pp_overlap
         suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}"
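Note on the hunk above: the gathered sharding meta now stores each parameter's shape and dtype (the dtype encoded as an int) next to the structure-name mapping, giving the reshard path enough information to allocate or validate tensors for parameters whose data lives on another rank (#8543). A minimal sketch of recording and consuming that kind of metadata; keeping the dtype object in memory here is a simplification of the int encoding used above:

import paddle

model = paddle.nn.Linear(4, 4)

# Record per-parameter metadata alongside the structure-name mapping.
# (The actual change stores int(v.dtype) so the value is JSON-serializable.)
param_meta = {}
for k, v in model.state_dict().items():
    param_meta[k] = (v.shape, v.dtype)

# A reshard-time consumer can allocate placeholders of the right shape and
# dtype before the real values arrive from the rank that owns them.
placeholders = {k: paddle.zeros(shape, dtype=dtype) for k, (shape, dtype) in param_meta.items()}
print({k: (t.shape, t.dtype) for k, t in placeholders.items()})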
