Skip to content

Update PaddleNLP to fix PPO #8618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ def create_scheduler(self, num_training_steps: int):
self.args.warmup_steps if self.args.warmup_steps > 0 else int(self.args.warmup_ratio * num_training_steps)
)
decay_steps = num_training_steps
if hasattr(self.args, "decay_steps") and self.args.decay_steps > 0:
if getattr(self.args, "decay_steps", None) and self.args.decay_steps > 0:
decay_steps = self.args.decay_steps

if self.lr_scheduler is None:
Expand Down
13 changes: 8 additions & 5 deletions paddlenlp/trainer/utils/sharding_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@
def set_optimizer(self, optimizer):
self.optimizer = optimizer

def load_state_dict_from_checkpoint_with_reshard(self, checkpoint, base_weight_name, model_wrapped):
def load_state_dict_from_checkpoint_with_reshard(
self, checkpoint, base_weight_name, model_wrapped, opt_state_dict=None
):
"""load state_dict from_checkpoint with reshard, Only load model state dict.
Args:
checkpoint (str): The directory of the checkpoint.
Expand Down Expand Up @@ -180,7 +182,7 @@
state_dict = reshard_util.all_gather_state_dict(state_dict, filter_func, self.sharding_group)

if self.args.bf16:
state_dict = self._recover_params_from_master_weights(state_dict)
state_dict = self._recover_params_from_master_weights(state_dict, opt_state_dict=opt_state_dict)

Check warning on line 185 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L185

Added line #L185 was not covered by tests

return state_dict

Expand Down Expand Up @@ -413,9 +415,10 @@
}
return parallel_config

def _recover_params_from_master_weights(self, state_dict):
opt_state_dict = self.optimizer.state_dict()
assert "master_weights" in opt_state_dict
def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None):
if opt_state_dict is None:
opt_state_dict = self.optimizer.state_dict()
assert "master_weights" in opt_state_dict, opt_state_dict.keys()

Check warning on line 421 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L419-L421

Added lines #L419 - L421 were not covered by tests
master_weights = opt_state_dict["master_weights"]
tmp = OrderedDict()
(master_weights, tmp) = (tmp, master_weights)
Expand Down
10 changes: 5 additions & 5 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,8 +1687,8 @@ def _load_pretrained_model(
model: PretrainedModel,
state_dict: Dict[str, Tensor],
loaded_keys: List[str],
resolved_archive_file: Union[str, List],
pretrained_model_name_or_path,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为什么需要改,需要适配什么情况?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原本是没有这些参数的,这个主要是为了兼容老的写法。

resolved_archive_file: Union[str, List] = [],
pretrained_model_name_or_path=None,
config=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
Expand Down Expand Up @@ -1743,7 +1743,7 @@ def _load_pretrained_model(
quantization_linear_list = [".".join([prefix, s]) for s in quantization_linear_list]

# Weight quantization if not yet quantized & update loaded_keys
if config.quantization_config.is_weight_quantize():
if hasattr(config, "quantization_config") and config.quantization_config.is_weight_quantize():
try:
from ..quantization.quantization_utils import (
convert_to_quantize_state_dict,
Expand Down Expand Up @@ -1873,7 +1873,7 @@ def _fuse_or_split_keys(
state_dict,
config,
loaded_keys,
pre_tensor_parallel_split=True if config.tensor_parallel_degree > 1 else False,
pre_tensor_parallel_split=True if config is not None and config.tensor_parallel_degree > 1 else False,
)
missing_keys = list(set(missing_keys) - set(new_keys))
unexpected_keys = list(set(unexpected_keys) - set(fused_keys))
Expand All @@ -1887,7 +1887,7 @@ def _fuse_or_split_keys(
ignore_mismatched_sizes,
)

if config.quantization_config.is_weight_quantize():
if hasattr(config, "quantization_config") and config.quantization_config.is_weight_quantize():
error_msgs = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
Expand Down
Loading