From 660d88755c9f6317384ba8e33cc209ae1365c317 Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Thu, 23 May 2024 17:09:36 +0800 Subject: [PATCH 1/6] fix bug of sharding format (#8483) --- paddlenlp/trainer/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 19faf04cc591..813ec9bc45ab 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1671,7 +1671,7 @@ def pipeline_parallel_rank(self): return 0 def _format_name(self, prefix, rank, degree): - size = max(2, len(str(degree))) + size = 2 return f"{prefix}{rank:0>{size}d}" @property From 09d4abd6278b3d91f73476becc81bfd74cebfa9d Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:58:39 +0800 Subject: [PATCH 2/6] Optimize the speed of set_state_dict (#8532) --- paddlenlp/trainer/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 14a0c6d6983c..da7d71cc85ae 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -564,7 +564,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): base_weight_name=weight_name, model_wrapped=self.model_wrapped, ) - self.model.set_state_dict(state_dict) + old_state_dict = self.model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + if k not in old_state_dict or id(v) != id(old_state_dict[k]): + new_state_dict[k] = v + self.model.set_state_dict(new_state_dict) else: if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel): From e90cc2aaed0d37edcf0ec21ee415099cf5fc0b5c Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 3 Jun 2024 22:19:12 +0800 Subject: [PATCH 3/6] fix sharding reshard save (#8535) --- paddlenlp/trainer/trainer.py | 22 +++++++++++++++------- paddlenlp/trainer/utils/sharding_io.py | 9 +++++---- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index da7d71cc85ae..213cb90cd7b9 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -889,7 +889,8 @@ def _inner_training_loop( npu_accelerate_plugin(self.optimizer) - self.timers and self.timers("read-data").start() + if not self.args.ignore_data_skip: + self.timers and self.timers("read-data").start() for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( @@ -903,7 +904,8 @@ def _inner_training_loop( for step, inputs in enumerate(epoch_iterator): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) - self.timers and self.timers("read-data").stop() + if not self.args.ignore_data_skip: + self.timers and self.timers("read-data").stop() os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) @@ -1087,7 +1089,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if self.control.should_epoch_stop or self.control.should_training_stop: break - self.timers and self.timers("read-data").start() + + if not self.args.ignore_data_skip: + self.timers and self.timers("read-data").start() if step < 0: logger.warning( @@ -2447,10 +2451,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ if state_dict is None: state_dict = self.model.state_dict() - self._save_ckpt_func( - state_dict, - os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)), - ) + if self.args.should_save_sharding_stage1_model: + state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config( + unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict) + variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix()) + else: + variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix) + + self._save_ckpt_func(state_dict, os.path.join(output_dir, variant)) else: if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model: config_to_save = None diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 4fe55d175005..1f7cd4eec2e7 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -321,12 +321,13 @@ def reshard_sharding(node_model_state): node_model_state = reshard_pp(node_model_state) return reshard_sharding(node_model_state) - def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False): + def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None): weight_name_suffix = self.args.sharded_name_suffix() - state_dict = model_to_save.state_dict() - if self.args.should_save_sharding_stage1_model: - state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group) + if state_dict is None: + state_dict = model_to_save.state_dict() + if self.args.should_save_sharding_stage1_model: + state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group) config_to_save = None merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel From db4e41b4785772d38726454fb61ac327af517a69 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 3 Jun 2024 22:28:58 +0800 Subject: [PATCH 4/6] Fix ignore_data_skip bug when timer is enabled (#8536) --- paddlenlp/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 213cb90cd7b9..c308e8a7539d 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -889,7 +889,7 @@ def _inner_training_loop( npu_accelerate_plugin(self.optimizer) - if not self.args.ignore_data_skip: + if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() for epoch in range(epochs_trained, num_train_epochs): @@ -904,7 +904,7 @@ def _inner_training_loop( for step, inputs in enumerate(epoch_iterator): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) - if not self.args.ignore_data_skip: + if self.args.ignore_data_skip: self.timers and self.timers("read-data").stop() os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) @@ -1090,7 +1090,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if self.control.should_epoch_stop or self.control.should_training_stop: break - if not self.args.ignore_data_skip: + if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() if step < 0: From f57457ebcd9898ee4dec1e76e6cda3ccaa943385 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:32:28 +0800 Subject: [PATCH 5/6] Save parameter shape and dtype when using sharding reshard (#8543) * save parameter shape and dtype * refactor --- paddlenlp/trainer/utils/sharding_io.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 1f7cd4eec2e7..f663d04d123a 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -385,7 +385,7 @@ def save_distributed_model_meta(self, dir): path = os.path.join(dir, MODEL_META_NAME) with open(path, "w") as f: - json.dump(model_meta, f, indent=4) + json.dump(model_meta, f) def _get_distributed_strategy(self): pp_degree = 1 @@ -545,13 +545,18 @@ def _gather_sharding_metas(self): pp_overlap = unwrap_optimizer(self.optimizer, DygraphShardingOptimizerV2).pp_overlap model = self.model - structure_name_mapping = {k: v.name for (k, v) in model.state_dict().items()} + structure_name_mapping = {} + param_meta = {} + for k, v in model.state_dict().items(): + structure_name_mapping[k] = v.name + param_meta[k] = (v.shape, int(v.dtype)) sharding_metas = {} sharding_meta = {} sharding_meta["param2rank"] = param2rank sharding_meta["structure_name_mapping"] = structure_name_mapping + sharding_meta["param_meta"] = param_meta sharding_meta["sharding_strategy"] = sharding_strategy sharding_meta["enable_overlap"] = pp_overlap suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}" From 91635e69609b1116967b0e2d999a63aac4d50f10 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 5 Jun 2024 12:13:52 +0800 Subject: [PATCH 6/6] format pre-commit --- paddlenlp/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index c308e8a7539d..01ca9165c345 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2453,7 +2453,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ if self.args.should_save_sharding_stage1_model: state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config( - unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict) + unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict + ) variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix()) else: variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)