From c825a7b71dd32e237e6c108bc3cf1b8fc1f3b527 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Tue, 18 Jun 2024 11:28:36 +0800
Subject: [PATCH 1/2] Update PaddleNLP to fix PPO

---
 paddlenlp/trainer/trainer.py           |  2 +-
 paddlenlp/trainer/utils/sharding_io.py | 11 ++++++-----
 paddlenlp/transformers/model_utils.py  | 10 +++++-----
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 5b974ad1d63b..01e5fccbc02e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1795,7 +1795,7 @@ def create_scheduler(self, num_training_steps: int):
             self.args.warmup_steps if self.args.warmup_steps > 0 else int(self.args.warmup_ratio * num_training_steps)
         )
         decay_steps = num_training_steps
-        if hasattr(self.args, "decay_steps") and self.args.decay_steps > 0:
+        if getattr(self.args, "decay_steps", None) and self.args.decay_steps > 0:
             decay_steps = self.args.decay_steps

         if self.lr_scheduler is None:
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index f663d04d123a..a4e4ea683f3f 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -126,7 +126,7 @@ def __init__(self, args, model, optimizer=None, hcg=None):
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer

-    def load_state_dict_from_checkpoint_with_reshard(self, checkpoint, base_weight_name, model_wrapped):
+    def load_state_dict_from_checkpoint_with_reshard(self, checkpoint, base_weight_name, model_wrapped, opt_state_dict=None):
         """load state_dict from_checkpoint with reshard, Only load model state dict.
         Args:
             checkpoint (str): The directory of the checkpoint.
@@ -180,7 +180,7 @@ def filter_func(name):
             state_dict = reshard_util.all_gather_state_dict(state_dict, filter_func, self.sharding_group)

         if self.args.bf16:
-            state_dict = self._recover_params_from_master_weights(state_dict)
+            state_dict = self._recover_params_from_master_weights(state_dict, opt_state_dict=opt_state_dict)

         return state_dict

@@ -413,9 +413,10 @@ def _get_distributed_strategy(self):
         }
         return parallel_config

-    def _recover_params_from_master_weights(self, state_dict):
-        opt_state_dict = self.optimizer.state_dict()
-        assert "master_weights" in opt_state_dict
+    def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None):
+        if opt_state_dict is None:
+            opt_state_dict = self.optimizer.state_dict()
+        assert "master_weights" in opt_state_dict, opt_state_dict.keys()
         master_weights = opt_state_dict["master_weights"]
         tmp = OrderedDict()
         (master_weights, tmp) = (tmp, master_weights)
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index 31ecadcabd2b..5bd70536a9d0 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -1687,8 +1687,8 @@ def _load_pretrained_model(
         model: PretrainedModel,
         state_dict: Dict[str, Tensor],
         loaded_keys: List[str],
-        resolved_archive_file: Union[str, List],
-        pretrained_model_name_or_path,
+        resolved_archive_file: Union[str, List] = [],
+        pretrained_model_name_or_path=None,
         config=None,
         ignore_mismatched_sizes=False,
         low_cpu_mem_usage=False,
@@ -1743,7 +1743,7 @@ def _load_pretrained_model(
                 quantization_linear_list = [".".join([prefix, s]) for s in quantization_linear_list]

         # Weight quantization if not yet quantized & update loaded_keys
-        if config.quantization_config.is_weight_quantize():
+        if hasattr(config, "quantization_config") and config.quantization_config.is_weight_quantize():
             try:
                 from ..quantization.quantization_utils import (
                     convert_to_quantize_state_dict,
@@ -1873,7 +1873,7 @@ def _fuse_or_split_keys(
                 state_dict,
                 config,
                 loaded_keys,
-                pre_tensor_parallel_split=True if config.tensor_parallel_degree > 1 else False,
+                pre_tensor_parallel_split=True if config is not None and config.tensor_parallel_degree > 1 else False,
             )
             missing_keys = list(set(missing_keys) - set(new_keys))
             unexpected_keys = list(set(unexpected_keys) - set(fused_keys))
@@ -1887,7 +1887,7 @@ def _fuse_or_split_keys(
             ignore_mismatched_sizes,
         )

-        if config.quantization_config.is_weight_quantize():
+        if hasattr(config, "quantization_config") and config.quantization_config.is_weight_quantize():
             error_msgs = _load_state_dict_into_meta_model(
                 model_to_load,
                 state_dict,

From f5fb8c21def5b3f724d1e9bb8e1a92206f8b5fb8 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Tue, 18 Jun 2024 11:35:12 +0800
Subject: [PATCH 2/2] fix lint

---
 paddlenlp/trainer/utils/sharding_io.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index a4e4ea683f3f..2d3d34c82d28 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -126,7 +126,9 @@ def __init__(self, args, model, optimizer=None, hcg=None):
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer

-    def load_state_dict_from_checkpoint_with_reshard(self, checkpoint, base_weight_name, model_wrapped, opt_state_dict=None):
+    def load_state_dict_from_checkpoint_with_reshard(
+        self, checkpoint, base_weight_name, model_wrapped, opt_state_dict=None
+    ):
         """load state_dict from_checkpoint with reshard, Only load model state dict.
         Args:
             checkpoint (str): The directory of the checkpoint.
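
Note (not part of the patches above): a minimal, self-contained sketch of the control-flow change that PATCH 1/2 makes to `_recover_params_from_master_weights` — an externally supplied `opt_state_dict` now takes precedence, and the live optimizer is queried only as a fallback. The names `recover_params_from_master_weights` and `_FakeOptimizer` below are illustrative stand-ins, not PaddleNLP APIs, and the real method's dtype and sharding bookkeeping is omitted.

```python
from collections import OrderedDict


class _FakeOptimizer:
    """Stand-in for an optimizer object; only state_dict() matters for this sketch."""

    def state_dict(self):
        return {"master_weights": {"linear.weight": [1.0, 2.0]}}


def recover_params_from_master_weights(state_dict, optimizer, opt_state_dict=None):
    # Patched control flow: an explicitly passed opt_state_dict wins;
    # the live optimizer is consulted only when none is provided.
    if opt_state_dict is None:
        opt_state_dict = optimizer.state_dict()
    assert "master_weights" in opt_state_dict, opt_state_dict.keys()
    recovered = OrderedDict(state_dict)
    # Simplified: the real PaddleNLP method does additional dtype casting
    # and sharding-group bookkeeping that is not shown here.
    recovered.update(opt_state_dict["master_weights"])
    return recovered


# Fallback path (pre-patch behavior): master weights come from the optimizer.
print(recover_params_from_master_weights({}, _FakeOptimizer()))

# New path enabled by the patch: master weights come from a checkpoint
# state dict loaded elsewhere (e.g. by the PPO trainer) and passed in.
external = {"master_weights": {"linear.weight": [3.0, 4.0]}}
print(recover_params_from_master_weights({}, _FakeOptimizer(), opt_state_dict=external))
```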