diff --git a/examples/RLHF/README.md b/examples/RLHF/README.md index 732a626061d6..478f8365833b 100644 --- a/examples/RLHF/README.md +++ b/examples/RLHF/README.md @@ -1,6 +1,6 @@ # RLHF PPO -提供了基于强化学习 PPO 算法对 LLM 进行人类偏好对齐的代码及完整使用示例。其中 PPO 代码实现细节参考了 [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver) 中的 PPO 实现,支持reward normalization、pretraining loss等常用的 PPO 稳定训练策略;示例使用 PKU-Alignment/safe-rlhf 提供的部分数据集和模型。后续将持续完善扩展,支持更好效果、更低成本、更高性能、更大规模的 RLHF 能力。 +提供了基于强化学习 PPO 算法对 LLM 进行人类偏好对齐的代码及完整使用示例,支持**3D 分布式并行训练以及 rollout 阶段使用预测优化进行生成加速**。其中 PPO 代码实现细节参考了 [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver) 中的 PPO 实现,支持reward normalization、pretraining loss等常用的 PPO 稳定训练策略;示例使用 PKU-Alignment/safe-rlhf 提供的部分数据集和模型。后续将持续完善扩展,支持更好效果、更低成本、更高性能、更大规模的 RLHF 能力。 ## 快速开始 @@ -14,6 +14,9 @@ ├── ppo_main.py # RLHF训练脚本 ├── ppo_config.json # RLHF训练配置文件 ├── ppo_trainer.py # RLHF训练执行器py脚本 +├── ppo_config.json # RLHF训练配置文件 +├── trainer_utils.py # Trainer补丁及工具py脚本 +├── infer_utils.py # 生成加速工具py脚本 ├── data # 数据集相关目录 │ └── base.py # 数据集基类及工具py文件 │ └── alpaca.py # alpaca(raw)数据集py文件 @@ -24,6 +27,10 @@ ├── models # 模型相关目录 │ └── score_model_utils.py # score model基类及工具py文件 │ └── score_model.py # score model模型定义py文件 +│ └── ppo_model_utils.py # PPO loss等模型策略py文件 +│ └── pp_model_utils.py # 流水线并行补丁及工具py文件 +│ └── model_pp.py # 流水线并行模型py文件 +│ └── infer_model_utils.py # 预测加速模型补丁及工具py文件 └── README.md ``` @@ -31,9 +38,9 @@ - Python >= 3.10 - PaddlePaddle >= 2.6.0 -- PaddleNLP >= 2.6.0 +- PaddleNLP 最新版本 -此外还需要安装以下依赖:`pip install rich` +如需使用生成加速功能,需要安装 [paddlenlp_ops](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/csrc) ,请使用 `git clone https://github.com/PaddlePaddle/PaddleNLP.git` 克隆 PaddleNLP 代码库并且将 PaddleNLP/llm 目录的路径加入 PYTHONPATH(后续将进行完善)。安装 paddlenlp_ops 后训练时将直接开启生成加速(开启流水线并行时不支持生成加速),否则使用原生动态图进行生成。 ### 数据准备 @@ -184,7 +191,8 @@ python -u -m paddle.distributed.launch reward_main.py ./reward_config.json RLHF 阶段需要 actor model、reference model、critic model、reward model 四个模型;actor-model/reference-model 使用 SFT 模型进行 initialize/frozen;critic-model/reward-model 使用 reward 模型进行 initialize/frozen (另外注意若 SFT 使用 LoRA 请先将 LoRA 权重合并)。这里使用 PKU-Alignment/PKU-SafeRLHF 提供的 SFT 模型([PKU-Alignment/alpaca-7b-reproduced](https://huggingface.co/PKU-Alignment/alpaca-7b-reproduced))和 reward 模型([PKU-Alignment/beaver-7b-v1.0-reward](https://huggingface.co/PKU-Alignment/beaver-7b-v1.0-reward),注意该模型只关注 helpful 未考量 harmless)作为示例,使用 `ppo_main.py` 脚本根据 `ppo_config.json` 进行 RLHF 训练。 ``` -python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json +# 类型提升 warning 暂时通过 loglevel 屏蔽,待后续修复 +GLOG_minloglevel=2 python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json ``` `ppo_config.json` 中的绝大部分参数释义同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),不再赘述,重点给出以下参数配置及释义(使用 PKU-Alignment/PKU-SafeRLHF 中的默认值): @@ -210,7 +218,15 @@ python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json 另外所有 [`TrainingArguments` 支持参数配置](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments)将为 actor-model 和 critic-model 的训练复用(如`sharding_stage`),除单独提供了 `critic_learning_rate/critic_weight_decay/critic_lr_scheduler_type/critic_warmup_ratio/critic_recompute` 这些参数支持为 critic-model 训练单独指定相应配置。actor-model 和 critic-model 的 checkpoints 将分别保存在 `outpt_dir` 所指定目录的 policy 和 value 文件夹下。 -当前示例中所用数据及规模 RLHF 训练基于 sharding stage3 使用 NVIDIA A100 80G 4卡/8卡训练验证。 +此外为了支持更高性、更大规模的 RLHF 训练提供了以下特殊参数配置,可以按需使用: +- `use_fusemt`:安装 paddlenlp_ops 后将在 rollout 
生成时开启生成加速(开启流水线并行时不支持生成加速),通过此设置可以禁用生成加速。 +- `eval_mode`:支持为空或者设置为 "single"、"tensor_parallel";通常可以在使用流水线并行训练时设置为"tensor_parallel",以此在 rollout 生成阶段使用非流水线并行模型并进行生成加速。 +- `offload_level`:支持设置为"freeze_model"、"optimizer"、"train_model"或者同时使用(空格分隔),分别指示 reward+reference 两个冻结模型、actor+critic 两个训练模型的优化器状态和模型参数的 offload/reload,用于在不同阶段 model/optimizer 使用结束后及时 offload 并在下次使用时 reload 相应参数权重以节省显存。 + +另外注意,在使用流水线并行时(pipeline_parallel_degree大于1)建议将 `dataloader_drop_last` 设置为 true, 以此避免不同batch size带来的问题。 + + + ### 推理 diff --git a/examples/RLHF/comm_utils.py b/examples/RLHF/comm_utils.py new file mode 100644 index 000000000000..de077c65db31 --- /dev/null +++ b/examples/RLHF/comm_utils.py @@ -0,0 +1,403 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import paddle +import paddle.distributed as dist + +from paddlenlp.trainer.plugins.unified_checkpoint import flatten_list +from paddlenlp.trainer.trainer import Trainer, logger +from paddlenlp.trainer.utils.helper import nested_broadcast_tensor_with_empty +from paddlenlp.utils.distributed import distributed_gather + +global_dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device().split(":")[1]) + + +def offload_tensor_to_cpu(tensors): + if isinstance(tensors, dict): + for _, v in tensors.items(): + offload_tensor_to_cpu(v) + elif isinstance(tensors, paddle.Tensor): + if tensors.place.is_gpu_place(): + cpu_tensor = tensors._copy_to(paddle.CUDAPinnedPlace(), False) + tensors.value().get_tensor()._share_data_with(cpu_tensor.value().get_tensor()) + else: + logger.warning(f"Can't parse for type {type(tensors)}") + return tensors + + +def reload_tensor_to_gpu(tensors): + if isinstance(tensors, dict): + for _, v in tensors.items(): + reload_tensor_to_gpu(v) + elif isinstance(tensors, paddle.Tensor): + if tensors._is_initialized() and not tensors.place.is_gpu_place(): + gpu_tensor = tensors._copy_to(paddle.CUDAPlace(global_dev_id), False) + tensors.value().get_tensor()._share_data_with(gpu_tensor.value().get_tensor()) + else: + logger.warning(f"Can't parse for type {type(tensors)}") + return tensors + + +def cleanup_tensor_space(tensors): + if isinstance(tensors, dict): + for _, v in tensors.items(): + cleanup_tensor_space(v) + elif isinstance(tensors, paddle.Tensor): + tensors._clear_data() + else: + logger.warning(f"Can't parse for type {type(tensors)}") + return tensors + + +def data_group_split(tensors, group): + if group is None: + return tensors + if isinstance(tensors, (list, tuple)): + return type(tensors)(data_group_split(t, group) for t in tensors) + elif isinstance(tensors, dict): + new_dict = {} + for k, v in tensors.items(): + new_dict[k] = data_group_split(v, group) + return new_dict + elif isinstance(tensors, paddle.Tensor): + return tensors.split(group.nranks)[group.rank] + else: + logger.warning(f"Can't parse for type {type(tensors)}") + return tensors + + +def data_group_merge(tensors, group): + if group is None: + return tensors + + if isinstance(tensors, 
(list, tuple)): + return type(tensors)(data_group_merge(t, group) for t in tensors) + elif isinstance(tensors, dict): + new_dict = {} + for k, v in tensors.items(): + new_dict[k] = data_group_merge(v, group) + return new_dict + elif isinstance(tensors, paddle.Tensor): + tensor_list = [] + all_gather_nd(tensor_list, tensors, group=group, padded=True) + return paddle.concat(tensor_list) + else: + logger.warning(f"Can't parse for type {type(tensors)}") + return tensors + + +def group_rank_guard(group, rank=0): + def decorator(func): + def wrapper_func(*args, **kwargs): + if group.rank == rank: + ret = func(*args, **kwargs) + dist.barrier() + else: + ret = None + dist.barrier() + ret = nested_broadcast_tensor_with_empty(ret, group=group) + return ret + + return wrapper_func + + return decorator + + +def repad_rl_batches(batches, input_lengths): + if batches.get("position_ids", None) is not None: + v = batches["position_ids"] + for x in range(v.shape[0]): + v[x, input_lengths[x] :] = 1 + batches["position_ids"] = v + for key in list(batches.keys()): + if batches[key].shape[0] != input_lengths.shape[0]: + batches[key] = batches[key].mean() + + return batches + + +# https://stackoverflow.com/questions/12594148/skipping-execution-of-with-block +class SkipWithBlock(Exception): + pass + + +class SkipContextManager: + def __init__(self, skip): + self.skip = skip + + def __enter__(self): + if self.skip: + sys.settrace(lambda *args, **keys: None) + frame = sys._getframe(1) + frame.f_trace = self.trace + + def trace(self, frame, event, arg): + raise SkipWithBlock() + + def __exit__(self, type, value, traceback): + if type is None: + return # No exception + if issubclass(type, SkipWithBlock): + return True # Suppress special SkipWithBlock exception + + +def all_gather_nd(tensor_list, tensor, group=None, padded=False): + """ + Gathers tensor arrays of different lengths in a list. + The length dimension is 0. This supports any number of extra dimensions in the tensors. + All the other dimensions should be equal between the tensors. + + Args: + tensor (Tensor): Tensor to be broadcast from current process. 
+ + Returns: + (Tensor): output list of tensors that can be of different sizes + """ + if len(tensor.shape) == 0: + tensor = tensor.reshape([1]) + dist.all_gather(tensor_list, tensor, group=group) + return tensor_list + + world_size = group.nranks + local_size = paddle.to_tensor(tensor.shape, place=tensor.place) + all_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(all_sizes, local_size, group=group) + + # max_length = max(size[0] for size in all_sizes) + + # length_diff = max_length.item() - local_size[0].item() + # if length_diff: + # pad_size = (length_diff, *tensor.size()[1:]) + # padding = paddle.zeros(pad_size, place=tensor.place(), dtype=tensor.dtype) + # tensor = padle.concat((tensor, padding)) + + max_length = max(size[-1] for size in all_sizes) + + length_diff = max_length.item() - local_size[-1].item() + if length_diff: + pad_size = (*tensor.shape[:-1], length_diff) + padding = paddle.zeros(pad_size, dtype=tensor.dtype) + tensor = paddle.concat([tensor, padding], axis=-1) + + all_tensors_padded = [] + dist.all_gather(all_tensors_padded, tensor, group=group) + # all_tensors = [] + if padded: + tensor_list.extend(all_tensors_padded) + return all_tensors_padded + + for tensor_, size in zip(all_tensors_padded, all_sizes): + tensor_list.append(tensor_[..., : size[-1]]) + return tensor_list + + +def export_evaluate_model(self: Trainer, train_model, eval_model, **kwargs): + if eval_model is None: + return None + + with_offload = kwargs.pop("with_offload", False) + train_tp_size = max(train_model.config.tensor_parallel_degree, 1) + eval_tp_size = max(eval_model.config.tensor_parallel_degree, 1) + eval_tp_rank = max(eval_model.config.tensor_parallel_rank, 0) + + hcg = dist.fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + sd_group = hcg.get_sharding_parallel_group() + dp_group = hcg.get_data_parallel_group() + + global_rank = paddle.distributed.get_rank() + + train_state_dict = train_model.state_dict() + eval_state_dict = eval_model.state_dict() + + if dp_group.rank <= 0 and sd_group.rank <= 0: + train_pp_size = pp_group.nranks + if eval_tp_size > 1 and train_tp_size != eval_tp_size: + raise ValueError("Only support for the same tensor_parallel_degree for train and eval model for now.") + + # 单卡情况 + # tp->single + # tp+pp -> single + if eval_tp_size == 1: + if train_pp_size == 1 and train_tp_size > 1: + # tp ->single + logger.error("using tp to single eval model.") + # state = train_model.merge_tensor_parallel() + tp_actions = train_model.get_tensor_parallel_convert_actions( + train_model.config, + loaded_state_dict_keys=eval_state_dict.keys(), + is_split=False, + ignore_error=False, + ) + + is_dst = global_rank == 0 + for key in eval_state_dict.keys(): + tensor = train_state_dict[key] + if key in tp_actions: + ret = distributed_gather(tensor, dst=0, group=tp_group, offload=False) + action = tp_actions.pop(key) + tensor = action(ret) if is_dst else None + else: + tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None + + if tensor is not None: + eval_state_dict[key].set_value(tensor) + + if not eval_state_dict[key]._is_initialized(): + v = eval_state_dict[key] + t = paddle._C_ops.full_like(v, 0, v.dtype, paddle.CUDAPlace(global_dev_id)) + v.get_tensor()._share_data_with(t.get_tensor()) + + if with_offload: + offload_tensor_to_cpu(train_state_dict[key]) + else: + # single to single + # tp+pp -> single + raise ValueError("Not support yet.") + + def 
create_send_recv_table(train_keys, eval_keys): + recv_table = [] + send_table = [] + if pp_group.rank == 0: + for key in eval_keys: + recv_table.append((key, global_rank)) + + for key in train_keys: + send_table.append((key, global_rank)) + + all_recv, all_send = [], [] + paddle.distributed.all_gather_object(all_recv, [recv_table], group=pp_group) + paddle.distributed.all_gather_object(all_send, [send_table], group=pp_group) + all_recv = flatten_list(all_recv) + all_send = flatten_list(all_send) + + send_dict = {} + for k, v in all_send: + send_dict[k] = v + + table = [] + for k, v in all_recv: + # key, send, recv + table.append([k, send_dict.pop(k), v]) + assert len(send_dict) == 0, f"Some key can't be recv {send_dict.keys()}" + return table + + # pp0tp0 -> pp0tp0 + # pp0tp1 -> pp0tp1 + # pp1tp0 -> pp0tp0 + # pp1tp1 -> pp0tp1 + + # tp情况 + # tp+pp->tp + self.timers and self.timers("export-merge-pp").start() + if eval_tp_size > 1 and train_pp_size > 1: + table = create_send_recv_table(train_state_dict.keys(), eval_state_dict.keys()) + + for key, src_rank, dst_rank in table: + # Init tensor for model is cleaned + if not eval_state_dict[key]._is_initialized(): + v = eval_state_dict[key] + t = paddle._C_ops.full_like(v, 0, v.dtype, paddle.CUDAPlace(global_dev_id)) + v.get_tensor()._share_data_with(t.get_tensor()) + + if src_rank == dst_rank and global_rank == src_rank: + eval_state_dict[key].copy_(train_state_dict[key], True) + else: + if global_rank == src_rank: + dist.stream.send(train_state_dict[key], dst=dst_rank) + + if global_rank == dst_rank: + dist.stream.recv(eval_state_dict[key], src=src_rank) + + # Offload train model if need + if global_rank == src_rank and with_offload: + offload_tensor_to_cpu(train_state_dict[key]) + + self.timers and self.timers("export-merge-pp").stop() + self.timers and self.timers("export-broadcast-pp").start() + if pp_group.nranks > 1: + paddle.distributed.parallel.sync_params_buffers( + eval_model, comm_group=pp_group, src_rank=pp_group.ranks[0], fuse_params=False + ) + self.timers and self.timers("export-broadcast-pp").stop() + else: + # 其他 DP rank 的state dict, 适配 offload 和初始化 + self.timers and self.timers("export-offload-and-init").start() + if with_offload: + for key in list(train_state_dict.keys()): + offload_tensor_to_cpu(train_state_dict[key]) + for k, v in eval_state_dict.items(): + if not v._is_initialized(): + t = paddle._C_ops.full_like(v, 0, v.dtype, paddle.CUDAPlace(global_dev_id)) + v.get_tensor()._share_data_with(t.get_tensor()) + self.timers and self.timers("export-offload-and-init").stop() + + paddle.distributed.barrier() + self.timers and self.timers("export-broadcast-sd-dp").start() + if eval_tp_size == 1: + for _, tensor in eval_state_dict.items(): + paddle.distributed.broadcast(tensor, src=0, group=None, sync_op=True) + else: + if sd_group.nranks > 1: + if dp_group.rank <= 0: + paddle.distributed.parallel.sync_params_buffers( + eval_model, comm_group=sd_group, src_rank=sd_group.ranks[0], fuse_params=False + ) + if dp_group.nranks > 1: + paddle.distributed.parallel.sync_params_buffers( + eval_model, comm_group=dp_group, src_rank=dp_group.ranks[0], fuse_params=False + ) + self.timers and self.timers("export-broadcast-sd-dp").stop() + # paddle.save(eval_state_dict, f"./tmp/eval_{sd_group.rank}_tp_{eval_tp_rank}_pp_{pp_group.rank}.pdparams") + # paddle.save(train_state_dict, f"./tmp/train_{sd_group.rank}_tp_{tp_group.rank}_pp_{pp_group.rank}.pdparams") + # paddle.distributed.barrier() + # exit(-1) + + old_dp_workers = self.args.world_size // 
(max(sd_group.nranks, 1) * max(dp_group.nranks, 1)) + group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank + + if not hasattr(self, "_policy_model_eval_group") or self._policy_model_eval_group is None: + self._policy_model_eval_group = create_data_trans_group(global_rank, group_nums) + + return None + + +def create_data_trans_group(global_rank, group_nums): + all_split_table = [] + paddle.distributed.all_gather_object(all_split_table, [(global_rank, group_nums)]) + all_split_table = flatten_list(all_split_table) + split_dict = {} + for k, v in all_split_table: + split_dict[k] = v + + split_ranks = {} + for k, v in all_split_table: + if v in split_ranks: + split_ranks[v].append(k) + else: + split_ranks[v] = [k] + + group = None + for k, ranks in split_ranks.items(): + gp = paddle.distributed.new_group(ranks=ranks) + if global_rank in ranks: + group = gp + + return group + + +Trainer.export_evaluate_model = export_evaluate_model diff --git a/examples/RLHF/infer_utils.py b/examples/RLHF/infer_utils.py new file mode 100644 index 000000000000..d0667aefe061 --- /dev/null +++ b/examples/RLHF/infer_utils.py @@ -0,0 +1,383 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import inspect +import types +from contextlib import contextmanager + +import paddle +import paddle.distributed as dist +from comm_utils import cleanup_tensor_space, offload_tensor_to_cpu, reload_tensor_to_gpu +from paddle.utils import try_import +from trainer_utils import guard_set_args + +import paddlenlp +from paddlenlp.trainer.trainer import Trainer, logger +from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer +from paddlenlp.transformers.model_utils import dtype_guard + + +class Predictor: + def __init__(self, config, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None): + self.model_config = model.config + self.config = config + self.tokenizer = tokenizer + self.model = model + self.is_available = False + self._weights_mapping = None + # TODO(guosheng): Removde dependency on llm.Predictor + # 1. buffer_maker creates caches and other buffer inputs can be shared + # among multi time prediction. define caches and extra inputs creation + # method instead of using predictor.__init__ + # 2. inputs_processer creates caches and other inputs can be shared among + # multi time prediction. 
define caches and extra inputs creation method + # instead of using predictor.__init__ + from predictor import InferencePredictorMixin + + self._buffer_maker = types.MethodType(InferencePredictorMixin.__init__, self) + self._inputs_processer = types.MethodType(InferencePredictorMixin._preprocess, self) + + @staticmethod + def create_predictor(trainer): + from predictor import ( + PdArgumentParser, + PredictorArgument, + get_model_max_position_embeddings, + ) + + # create infer model + # NOTE: infer model use static name param_attr to create and cannot be + # created multiple times. + def create_infer_model(model, dtype, set_state=False): + from models.infer_model_utils import patch_infer_generate + + # apply patches to make FuseMT adapt + patch_infer_generate( + eos_token_id=trainer.tokenizer.eos_token_id, pad_token_id=trainer.tokenizer.pad_token_id + ) + config = copy.deepcopy(model.config) + hcg = dist.fleet.get_hybrid_communicate_group() # may differ with training + config.tensor_parallel_degree = hcg.get_model_parallel_world_size() + config.tensor_parallel_rank = hcg.get_model_parallel_rank() + config.weight_only_quant_bits = -1 + config.quant_type = None + config.use_cachekv_int8 = False + config.single_card_ptq = True + infer_model_cls = getattr(paddlenlp.experimental.transformers, model.__class__.__name__ + "InferenceModel") + # ori_init_weights = infer_model_cls.init_weights + # infer_model_cls.init_weights = lambda self: None + with dtype_guard(dtype): + infer_model = infer_model_cls(config) + # infer_model_cls.init_weights = ori_init_weights + + if set_state: + state_dict = {} + for k, v in model.state_dict().items(): + # state_dict[k] = np.from_dlpack(paddle.utils.dlpack.to_dlpack(v)) + state_dict[k] = v.numpy() + infer_model.set_state_dict(state_dict) + return infer_model + + # to avoid oom, clear param of infer_model imediately + ori_creat_param = paddle.nn.Layer.create_parameter + + def _create_param(self, *args, **kwargs): + param = ori_creat_param(self, *args, **kwargs) + param._clear_data() + # param._clear() + return param + + paddle.nn.Layer.create_parameter = _create_param + # trainer might use an extra model instead of trainer.model for eval + eval_model = getattr(trainer, "_inner_eval_model", None) + infer_model = create_infer_model(trainer.model if eval_model is None else eval_model, dtype=trainer.amp_dtype) + paddle.nn.Layer.create_parameter = ori_creat_param + # for k, v in infer_model.state_dict().items(): + # v._clear() + + # create predictor + parser = PdArgumentParser((PredictorArgument,)) + predictor_args = parser.parse_dict( + { + "src_length": get_model_max_position_embeddings( # can be changed dynamically by predictor.input_length + trainer.model.config if eval_model is None else eval_model.config + ), + "max_length": trainer.args.max_length, + "dtype": trainer.amp_dtype, + "batch_size": trainer.args.per_device_train_batch_size, + # infer model do not support top_k, and differ with non-infer model + # generation which gets default top_K=50 using generation_config.top_k + "top_p": trainer.args.top_p, + "temperature": trainer.args.temperature, + "repetition_penalty": trainer.args.repetition_penalty, + } + )[0] + policy_predictor = Predictor(predictor_args, model=infer_model, tokenizer=trainer.tokenizer) + return policy_predictor + + def _create_caches(self): + """inputs can be reused among multiple predictions, such as cache""" + if hasattr(self, "cache_kvs_shape"): # has created cache + input_length = getattr(self, "input_length", 0) + # TODO(guosheng): 
better way to get history max cahce length, we can + # not get cahce length form cache tensor when not know cache layout + if input_length <= self.config.src_length: # reuse cahce + return + else: # create longer cache + self._clear_caches() + self.config.src_length = getattr(self, "input_length", self.config.src_length) + if not hasattr(self, "_buffer_attrs"): + pre_attrs = set(self.__dict__.keys()) + self.cache_kvs_shape = self.model.get_cache_kvs_shape( + self.model_config, self.config.batch_size, self.config.total_max_length + ) + self._buffer_maker(self.config, self.tokenizer) + if not hasattr(self, "_buffer_attrs"): + self._buffer_attrs = set(self.__dict__.keys()) - pre_attrs + + def _clear_caches(self): + # del or offload + for attr in self._buffer_attrs: + delattr(self, attr) + + def disable(self, model, onload_model=True): + # clear caches + self._clear_caches() + # clear params + for _, param in self.model.state_dict().items(): + param._clear_data() + # param._clear() + if onload_model: + model.to(paddle.device.get_device()) + self.is_available = False + + def enable(self, model, offload_model=True): + if self.is_available: + return + # set params + self.set_state_dict(model, offload_model) + self.is_available = True + + @paddle.no_grad() + def set_state_dict(self, model, offload_model=True): + offload_place = paddle.CUDAPinnedPlace() + state_dict = {} + for k, v in model.state_dict().items(): + state_dict[k] = v + + if getattr(self, "_weights_mapping", None) is None: + self._weights_mapping = self.model.get_weights_mapping() + + # non_share_params = [] + for k, v in self._weights_mapping.items(): + param, (convert_fun, args) = k, v + args = [state_dict[name] for name in args] + value = convert_fun(*args) + if offload_model: + for arg in args: + # shared params no need to offload + if value is not arg: + cpu_arg = arg._copy_to(offload_place, blocking=False) + cpu_arg._share_buffer_to(arg) + if not isinstance(value, paddle.Tensor): + param.set_value(value) + # elif isinstance(value.place, paddle.CUDAPlace): + elif value.place.is_gpu_place(): + # NOTE: _share_buffer_to seems do not work + # value._share_buffer_to(param) + # value._share_underline_tensor_to(param) + param.get_tensor()._share_data_with(value.get_tensor()) + else: + param.copy_(value, True) + + paddle.device.cuda.synchronize() + + def _preprocess(self, source): + # make cache when infer happens to get actual shape to save memory + self._create_caches() + with guard_set_args(self.config, {"src_length": getattr(self, "input_length", self.config.src_length)}): + inputs = self._inputs_processer(source) + # We want to use a defined input_length to create cache and input_ids. + # However predictor could not use a specified length to pad currently. + # Thus we use this way to let get the actual input length. 
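`guard_set_args` is imported from `trainer_utils.py`, which is outside this excerpt; it temporarily overrides attributes on an object (here `src_length` on the predictor config) and restores them afterwards. A minimal sketch of such a guard, given as an assumption rather than the actual `trainer_utils.py` implementation:

```python
from contextlib import contextmanager

@contextmanager
def guard_set_args(obj, attr_dict):
    # temporarily override attributes, restoring the originals on exit
    saved = {name: getattr(obj, name, None) for name in attr_dict}
    try:
        for name, value in attr_dict.items():
            setattr(obj, name, value)
        yield
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

class _Config:
    src_length = 1024

config = _Config()
with guard_set_args(config, {"src_length": 4096}):
    assert config.src_length == 4096  # overridden inside the guard
assert config.src_length == 1024      # restored afterwards
```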
+ self.infer_input_length = inputs["input_ids"].shape[-1] + return inputs + + @paddle.no_grad() + def _infer(self, inputs): + for key in inputs.keys(): + if paddle.is_tensor(inputs[key]): + continue + if isinstance(inputs[key], list): + if paddle.is_tensor(inputs[key]): + continue + inputs[key] = [paddle.to_tensor(item) for item in inputs[key]] + else: + inputs[key] = paddle.to_tensor(inputs[key]) + + inputs["cache_kvs"] = self.cache_kvs + return self.model.generate(**inputs) + + def _postprocess(self, predictions): + return predictions + + @paddle.no_grad() + def predict(self, input_texts: str | list[str]): + tokenized_source = self._preprocess(input_texts) + predictions = self._infer(tokenized_source) + decoded_predictions = self._postprocess(predictions) + return decoded_predictions + + +policy_predictor: Predictor = None + + +@contextmanager +def infer_guard(trainer, offload_model=True): + # trainer might use an extra model instead of trainer.model for eval + eval_model = getattr(trainer, "_inner_eval_model", None) + model = trainer.model if eval_model is None else eval_model + + # PipelineParallel does not support inference speedup + if not getattr(trainer, "use_fusemt", False) or isinstance( + model, (dist.fleet.meta_parallel.PipelineLayer, dist.fleet.model.PipelineParallel) + ): + yield + return + + try: + try_import("paddlenlp_ops") + except: + logger.warning("paddlenlp_ops does not exist, please install paddlenlp_ops for generation speedup.") + yield + return + + global policy_predictor + if policy_predictor is None: + policy_predictor = Predictor.create_predictor(trainer) + if not policy_predictor.is_available: + policy_predictor.enable(model, offload_model=offload_model) + + # TODO(guosheng): patch for dist.all_recude to use tp group, fix it later + ori_all_reduce = dist.all_reduce + ori_broadcast = dist.broadcast + hcg = dist.fleet.get_hybrid_communicate_group() + dist.all_reduce = lambda x: ori_all_reduce(x, group=hcg.get_model_parallel_group()) + dist.broadcast = lambda x, rank: ori_broadcast( + x, src=hcg.get_model_parallel_group_src_rank(), group=hcg.get_model_parallel_group() + ) + yield + dist.all_reduce = ori_all_reduce + dist.broadcast = ori_broadcast + + policy_predictor.disable(model, onload_model=offload_model) + + +class InferEvalModel: + """For faster generation, not support PipelineParallel yet.""" + + def __init__(self, trainer: Trainer): + # trainer might use an extra model instead of trainer.model for eval + eval_model = getattr(trainer, "_inner_eval_model", None) + self.model: PretrainedModel = trainer.model if eval_model is None else eval_model + self.tokenizer: PretrainedTokenizer = trainer.tokenizer + self.trainer = trainer + + def enable(self): + trainer = self.trainer + if trainer.model is not self.model: + trainer.export_evaluate_model( + trainer.model, + self.model, + with_offload="train_model" in trainer.args.offload_level, + ) + else: + reload_tensor_to_gpu(self.model.state_dict()) + + def disable(self): + trainer = self.trainer + if trainer.model is not self.model: + cleanup_tensor_space(self.model.state_dict()) + else: + offload_tensor_to_cpu(self.model.state_dict()) + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.model, name) + + def eval(self): + self.model.eval() + + def train(self): + self.model.train() + + def __call__(self, *args, **kwargs): + # assert model is on GPU + assert policy_predictor is None or not policy_predictor.is_available + return self.model(*args, 
**kwargs) + + def generate(self, *args, **kwargs): + if policy_predictor is None or not policy_predictor.is_available: + return self.model.generate(*args, **kwargs) + + arg_dict = inspect.signature(self.model.generate).bind(*args, **kwargs).arguments + input_ids = arg_dict["input_ids"] + generation_config = arg_dict["generation_config"] + # convert text and tokenize again to convert left padding to right padding + # remove this if inputs is right padding + # TODO(guosheng): allow to use right padding to infer directly + prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + # decoded prompts has been applied with chat_template + # NOTE(guosheng): Whether to add special token should be checked, None + # chat_template would not add special token in predictor, since it assumes + # chat_template includes special tokens. While Beaver dataset tokenization + # does not use chat_template, it uses hard coded template which excludes + # special tokens. + with guard_set_args( + policy_predictor.tokenizer, + { + # predictor use right padding for infer model by default + # "padding_side": "right", + # "chat_template": None + }, + ): + # NOTE: right padding in predictor according to prompt might have a + # different length with input_ids, espically when input_ids has more + # paddings than the necessary. Thus pass input_length to predictor to: + # 1. use a consistent length to replace input_ids back to output to + # keep the same padding format. however predictor could not use a + # specified length to pad currently + # 2. allow to use a dynamic length for memory efficiency (by a smaller + # cache) + policy_predictor.input_length = input_ids.shape[-1] + outputs = policy_predictor.predict(prompts) + + if generation_config.trunc_input: + outputs = (outputs[0][:, policy_predictor.infer_input_length :],) + return outputs + + if policy_predictor.input_length != policy_predictor.infer_input_length: + outputs = (paddle.concat([input_ids, outputs[0][:, policy_predictor.infer_input_length :]], axis=-1),) + return outputs + + outputs = (outputs[0],) + if self.tokenizer.padding_side == "left": + # convert back to left padding inputs + outputs[0][:, : input_ids.shape[-1]] = input_ids + return outputs diff --git a/examples/RLHF/models/infer_model_utils.py b/examples/RLHF/models/infer_model_utils.py new file mode 100644 index 000000000000..3d63fe52aa9b --- /dev/null +++ b/examples/RLHF/models/infer_model_utils.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for inference model.""" + +import numpy as np +import paddle + + +def patch_paddlenlp_ops(eos_token_id, pad_token_id): + import paddlenlp_ops + + paddlenlp_ops.save_with_output = lambda *args, **kwargs: None + + # TODO(guosheng): update the custom op code directly. 
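The patch below rewrites tokens at stopped positions so that, when `pad_token_id` differs from `eos_token_id`, the inference model pads with `pad_token_id` like the non-infer path. A toy check of the `paddle.where` step it relies on, with made-up token ids:

```python
import paddle

eos_token_id, pad_token_id = 2, 0  # made-up ids for illustration
topk_ids = paddle.to_tensor([5, eos_token_id, 7], dtype="int64")
stop_flags = paddle.to_tensor([False, True, True])
# stopped positions are rewritten to pad_token_id, matching the intent of the
# patched set_stop_value_multi_ends below
out = paddle.where(stop_flags, paddle.full_like(topk_ids, pad_token_id), topk_ids)
print(out.numpy())  # [5 0 0]
```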
+ ori_set_ends = paddlenlp_ops.set_stop_value_multi_ends + + def _set_ends(topk_ids, stop_flags, end_ids, mode): + # infer model uses eos_token_id to pad and discriminate ending, + # patch to use pad_token_id to pad to unify with non-infer model. + topk_ids_out, stop_flags_out = ori_set_ends(topk_ids, stop_flags, end_ids, mode) + if pad_token_id != eos_token_id: + topk_ids_out = paddle.where(stop_flags, pad_token_id, topk_ids_out) + return topk_ids_out, stop_flags_out + + paddlenlp_ops.set_stop_value_multi_ends = _set_ends + + +def patch_infer_generate(eos_token_id, pad_token_id): + """patches for inference model to make FuseMT adapt""" + # patch paddlenlp_ops, maybe update the custom op code directly later + # NOTE: should patch paddlenlp_ops before infer model import + patch_paddlenlp_ops(eos_token_id, pad_token_id) + + # patch get_weights_mapping for InferenceModel + patch_infer_model() + + # patch GenerationInferenceModel.sample + from paddlenlp.experimental.transformers.generation_utils import ( + GenerationInferenceModel, + ) + + ori_update_model_kwargs = GenerationInferenceModel.update_model_kwargs_for_generation + + def _update_model_kwargs(self, *args, **kwargs): + # update_model_kwargs_for_generation only returns , hack to. + model_kwargs = ori_update_model_kwargs(self, *args, **kwargs) + next_tokens = model_kwargs["next_tokens"] + all_input_ids = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1) + model_kwargs["next_tokens"] = all_input_ids + model_kwargs["all_input_ids"] = None + return model_kwargs + + GenerationInferenceModel.update_model_kwargs_for_generation = _update_model_kwargs + + +_model_weights_mapping_dict = {} + + +def register_model(model_cls_name): + def mark_cls_name(func): + # Do not register here although we can, otherwise infer model would import + # before paddlenlp_ops. 
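The Llama mapping registered below fuses each layer's q/k/v projections into one QKV weight for the inference model. A toy shape check of that fusion with small made-up sizes (tensor parallel degree 1):

```python
import numpy as np

hidden_size, num_heads, tp_degree = 8, 2, 1  # toy sizes for illustration
head_size = hidden_size // num_heads
q = np.random.rand(hidden_size, hidden_size).astype("float32")
k = np.random.rand(hidden_size, hidden_size).astype("float32")
v = np.random.rand(hidden_size, hidden_size).astype("float32")
# concat along the output dim, transpose, then reshape to the fused layout,
# the same steps as _concat_qkv in the registered mapping
fused = (
    np.concatenate([q, k, v], axis=-1)
    .transpose(1, 0)
    .reshape(3 * (num_heads // tp_degree) * head_size, hidden_size)
)
assert fused.shape == (3 * hidden_size, hidden_size)  # (24, 8)
```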
+ _model_weights_mapping_dict[model_cls_name] = func + return func + + return mark_cls_name + + +def patch_infer_model(): + import paddlenlp.experimental.transformers as infer_transformers + + for model_cls_name, get_weights_mapping in _model_weights_mapping_dict.items(): + model_cls = getattr(infer_transformers, model_cls_name) + model_cls.get_weights_mapping = get_weights_mapping + + +@register_model("LlamaForCausalLMInferenceModel") +def get_weights_mapping(self): + """model to infer model""" + head_size = self.config.hidden_size // self.config.num_attention_heads + + def _concat_qkv(q, k, v): + if isinstance(q, paddle.Tensor): + concated_qkv_weight = paddle.concat([q, k, v], axis=-1).T.reshape( + [ + 3 * (self.config.num_attention_heads // self.config.tensor_parallel_degree) * (head_size), + self.config.hidden_size, + ] + ) + else: + concated_qkv_weight = ( + np.concatenate( + [q, k, v], + axis=-1, + ) + .transpose(1, 0) + .reshape( + 3 * (self.config.num_attention_heads // self.config.tensor_parallel_degree) * (head_size), + self.config.hidden_size, + ) + ) + + return concated_qkv_weight + + def _concat_ffn1(w1, w2): + if isinstance(w1, paddle.Tensor): + concated_ffn1_weight = paddle.concat([w1, w2], axis=-1) + else: + concated_ffn1_weight = np.concatenate([w1, w2], axis=-1) + return concated_ffn1_weight + + identity = lambda x: x + + weight_mapping = {} + weight_mapping[self.lm_head.weight] = [ + identity, + [ + "lm_head.weight", + ], + ] + weight_mapping[self.llama.embed_tokens.weight] = [ + identity, + [ + "llama.embed_tokens.weight", + ], + ] + weight_mapping[self.llama.norm.weight] = [ + identity, + [ + "llama.norm.weight", + ], + ] + for idx in range(self.config.num_hidden_layers): + weight_mapping[self.llama.transformer_block.qkv_weights[idx]] = [ + _concat_qkv, + [ + f"llama.layers.{idx}.self_attn.q_proj.weight", + f"llama.layers.{idx}.self_attn.k_proj.weight", + f"llama.layers.{idx}.self_attn.v_proj.weight", + ], + ] + weight_mapping[self.llama.transformer_block.ffn1_weights[idx]] = [ + _concat_ffn1, + [ + f"llama.layers.{idx}.mlp.gate_proj.weight", + f"llama.layers.{idx}.mlp.up_proj.weight", + ], + ] + weight_mapping[self.llama.transformer_block.linear_weights[idx]] = [ + identity, + [ + f"llama.layers.{idx}.self_attn.o_proj.weight", + ], + ] + weight_mapping[self.llama.transformer_block.ffn2_weights[idx]] = [ + identity, + [ + f"llama.layers.{idx}.mlp.down_proj.weight", + ], + ] + weight_mapping[self.llama.transformer_block.ln_scales[idx]] = [ + identity, + [ + f"llama.layers.{idx}.input_layernorm.weight", + ], + ] + weight_mapping[self.llama.transformer_block.ffn_ln_scales[idx]] = [ + identity, + [ + f"llama.layers.{idx}.post_attention_layernorm.weight", + ], + ] + return weight_mapping diff --git a/examples/RLHF/models/model_pp.py b/examples/RLHF/models/model_pp.py new file mode 100644 index 000000000000..57bcf9f465cf --- /dev/null +++ b/examples/RLHF/models/model_pp.py @@ -0,0 +1,261 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import LayerDesc + +from paddlenlp.transformers import LlamaForCausalLM, LlamaForCausalLMPipe +from paddlenlp.transformers.llama.modeling import LlamaDecoderLayer +from paddlenlp.transformers.llama.modeling_pp import ( + LlamaRMSNormPipe, + parse_args, + return_args, +) + +from .pp_model_utils import fwd_args_to_dict, get_expected_keys, pad_batches_inputs +from .ppo_model_utils import ( + RLHFPPOMixedLoss, + RLHFValueLoss, + create_loss, + make_position_ids, +) +from .score_model_utils import ScoreModelMixin + +# patches for base pipe model +# non-pipe model class, can be used to parse and convert forward args +# mainly used for generation with PipelienParallel model +LlamaForCausalLMPipe._non_pipe_model_class = LlamaForCausalLM +LlamaForCausalLMPipe._non_pipe_decoder_layer_class = LlamaDecoderLayer + + +class LlamaPolicyPipe(LlamaForCausalLMPipe): + # TODO(guosheng): maybe make a Mixin is better + + @fwd_args_to_dict + def _prepare_pipeline_inputs_func(self, inputs): + # first_stage_keys = ["input_ids", "attention_mask"] + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + # last_stage_keys = [ + # "labels", "input_ids", "log_probs", "advantages", "sequence_mask" + # ] + # TODO(guosheng): make input keys same with model arg names, maybe we + # can use inspect and set as global var which can then be used here and + # in PPOTrainer. + last_stage_keys = ["labels", "input_ids", "old_log_probs", "reward_advantages", "sequence_mask"] + + if type(inputs) is dict: + # for left padding, position_ids is nececessary + if "position_ids" not in inputs: + inputs["position_ids"] = make_position_ids(inputs["attention_mask"]) + # ppo-loss and ptx-loss need different labels, and data iter provides + # corrensponding data, thus add the not provided fields here. + # policy trian and infer has different inputs, infer uses position_ids. + # for key in last_stage_keys: + for key in first_stage_keys + last_stage_keys: + if key not in inputs: + inputs[key] = None + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + for data in inputs: + # for key in last_stage_keys: + for key in first_stage_keys + last_stage_keys: + if key not in data: + if key == "position_ids": + data[key] = make_position_ids(data["attention_mask"]) + continue + data[key] = None + # keys = list(inputs[0].keys()) + inputs_batch = {key: [data.get(key) for data in inputs] for key in first_stage_keys + last_stage_keys} + # NOTE(guosheng): PipelineParallel requires send/recv tensors among + # micro-batches/accu-steps have the same shape. Thus pad here, maybe + # should make data collator do padding and pad optionally here, since + # padding strategy may not be clear here. + # 1. For input_ids/attention_mask/labels (prompt+target) padding: + # Some data fields, such as input_ids/attention_mask/labels, should + # have same shape after padding, and each of them cannot pad only + # according to its own max length which might be different since the + # filed value is None for different batches/tasks. 
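`pad_batches_inputs` (from `pp_model_utils.py` in this patch) pads each micro-batch to a shared length so that send/recv tensors across accumulation steps have identical shapes. A small self-contained illustration of the same padding scheme:

```python
import paddle

# two micro-batches whose prompt+target lengths differ (5 vs 7 tokens)
a = paddle.ones([2, 5], dtype="int64")
b = paddle.ones([2, 7], dtype="int64")
max_len = max(x.shape[-1] for x in [a, b])
pad_len = [max_len - x.shape[-1] for x in [a, b]]
padded = [
    paddle.concat([x, paddle.full([x.shape[0], p], 0, dtype=x.dtype)], axis=-1) if p else x
    for x, p in zip([a, b], pad_len)
]
assert [t.shape for t in padded] == [[2, 7], [2, 7]]
```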
+ src_tgt_keys = ["input_ids", "attention_mask", "labels", "position_ids"] + max_len = max([x.shape[-1] for x in inputs_batch["input_ids"]]) + pad_len = [max_len - x.shape[-1] for x in inputs_batch["input_ids"]] + for key in src_tgt_keys: + # Do not pad position_ids with 0 since 0s in position_ids has special + # usage in reward model. We use 1 to pad. + padding_value = self._ignore_index if key == "labels" else 1 if key == "position_ids" else 0 + inputs_batch[key] = pad_batches_inputs(inputs_batch[key], padding_value, pad_len=pad_len) + # 2. For old_log_probs/reward_advantages/sequence_mask (target) padding: + # hard to pad acorss batches, think in some cases one batch might have the + # longest prompt+target length but the shortest target lengh, which might + # cause mismatch between inputs with prompt+target length and labels with + # target length. NOTE: however trick can be used here, label fields with + # target length such as old_log_probs/reward_advantages/sequence_mask do + # not need to join comm and thus there is no need to keep same shape among + # batches of accumulation steps, they just need to pad as prompt+target + # fields such as input_ids. + tgt_keys = ["old_log_probs", "reward_advantages", "sequence_mask"] + for key in tgt_keys: + padding_value = 0 + inputs_batch[key] = pad_batches_inputs(inputs_batch[key], padding_value, pad_len=pad_len) + # for key, value in inputs_batch.items(): + # padding_value = self._ignore_index if key == "labels" else 0 + # max_len = max_len if key in [ + # "input_ids", "attention_mask", "labels" + # ] else None + # inputs_batch[key] = pad_batches_inputs(value, padding_value, max_len) + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + def __init__(self, config, **kwargs): + # NOTE: make _sequential_layers/_single_to_pp_mapping/_pp_to_single_mapping + # instance attrs instead of class attrs to support more than one pipeline + # models. Maybe make all sequential_layers add once. 
+ self._sequential_layers = [] + self._single_to_pp_mapping = None + self._pp_to_single_mapping = None + # To be consistent with score model init and allow hyper-param be passed + # using __init__/from_pretrained + self._init_kwargs = kwargs + super().__init__(config) + self._ignore_index = self._loss_fn.sft_criterion.ignore_index + + def get_loss_fn(self, config): + return create_loss(RLHFPPOMixedLoss, config, self._init_kwargs) + + @property + def head_out_meta(self): + """mainly for eval/generation with PipelineParallel""" + # None means to use actual data info + return paddle.static.InputSpec(shape=[None, None, self.config.vocab_size], dtype=None) + + +class _LlamaRMSNormPipe(LlamaRMSNormPipe): + """ + We need position_ids for reward model, so wrap LlamaRMSNormPipe to pass position_ids + """ + + def __init__(self, config): + super().__init__(config) + + def forward(self, args): + hidden_states, attention_mask, position_ids, alibi = parse_args(args) + return return_args(self.norm(hidden_states), attention_mask, position_ids) + + +# LayerDesc of PipelineParallel requires head to be a nn.Layer +class ValueHead(nn.Layer, ScoreModelMixin): + def __init__(self, config, **kwargs): + super().__init__() + self.config = config + self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) + + def forward(self, args): + # attention_mask passed from pre-stage is shaped (bs, 1, seq_len, seq_len) + hidden_state, attention_mask, position_ids, alibi = parse_args(args) + outputs = self.get_score( + hidden_state, attention_mask=attention_mask, position_ids=position_ids, return_dict=True + ) + return outputs + + +class LlamaValuePipe(LlamaForCausalLMPipe): + # TODO(guosheng): maybe make a Mixin is better + + @fwd_args_to_dict + def _prepare_pipeline_inputs_func(self, inputs): + # ValueHead/get_score needs original attention_mask or position_ids, + # while attention_mask passed from pre-stage is not the original, thus + # hack for position_ids here. + # Maybe add position_ids into inputs later and use position_ids instead + # of attention_mask to get score not only for pipeline parallel. + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + # TODO(guosheng): make input keys same with model arg names, maybe we + # can use inspect and set as global var which can then be used here and + # in PPOTrainer. + last_stage_keys = ["old_reward_values", "reward_returns", "sequence_mask"] + + if type(inputs) is dict: + if "position_ids" not in inputs: + inputs["position_ids"] = make_position_ids(inputs["attention_mask"]) + + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + for data in inputs: + if "position_ids" not in data: + data["position_ids"] = make_position_ids(data["attention_mask"]) + # keys = list(inputs[0].keys()) + inputs_batch = {key: [data.get(key) for data in inputs] for key in first_stage_keys + last_stage_keys} + # 1. For input_ids/attention_mask (prompt+target) padding: + # src_tgt_keys = ["input_ids", "attention_mask"] + src_tgt_keys = ["input_ids", "attention_mask", "position_ids"] + max_len = max([x.shape[-1] for x in inputs_batch["input_ids"]]) + pad_len = [max_len - x.shape[-1] for x in inputs_batch["input_ids"]] + for key in src_tgt_keys: + # Do not pad position_ids with 0 since 0s in position_ids has special + # usage in reward model. We use 1 to pad. 
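`make_position_ids` (defined in `ppo_model_utils.py` in this patch) derives position ids from a padding mask and maps padding slots to 0, which is why 0 cannot double as the padding value here. A short left-padding example of that computation:

```python
import paddle

# left-padded attention mask: two pads followed by three real tokens
attention_mask = paddle.to_tensor([[0, 0, 1, 1, 1]], dtype="int64")
position_ids = attention_mask.cumsum(-1) - 1
# padding slots come out as -1; map them back to 0, as make_position_ids does
position_ids = paddle.where(position_ids == -1, attention_mask, position_ids)
print(position_ids.numpy())  # [[0 0 0 1 2]]
```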
+ padding_value = self._ignore_index if key == "labels" else 1 if key == "position_ids" else 0 + inputs_batch[key] = pad_batches_inputs(inputs_batch[key], padding_value, pad_len=pad_len) + # 2. For old_reward_values/reward_returns/sequence_mask (target) padding: + tgt_keys = ["old_reward_values", "reward_returns", "sequence_mask"] + for key in tgt_keys: + padding_value = 0 + inputs_batch[key] = pad_batches_inputs(inputs_batch[key], padding_value, pad_len=pad_len) + # for key, value in inputs_batch.items(): + # inputs_batch[key] = pad_batches_inputs(value, padding_value=0) + # if "position_ids" not in inputs[0]: + # inputs_batch["position_ids"] = [ + # make_position_ids(attention_mask) for attention_mask in inputs_batch["attention_mask"] + # ] + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + def __init__(self, config, **kwargs): + # NOTE: make _sequential_layers/_single_to_pp_mapping/_pp_to_single_mapping + # instance attrs instead of class attrs to support more than one pipeline + # models. Maybe make all sequential_layers add once. + self._sequential_layers = [] + self._single_to_pp_mapping = None + self._pp_to_single_mapping = None + # To be consistent with score model init and allow hyper-param be passed + # using __init__/from_pretrained + self._init_kwargs = kwargs + super().__init__(config) + + def add_head(self, config): + init_kwargs = self._init_kwargs + # hack to replace original RMSNormPipe to support ValueHead inputs + norm_prefix = self._sequential_layers.pop(-1)["name_prefix"] + self.add_sequential_layer(LayerDesc(_LlamaRMSNormPipe, config=config), norm_prefix) + self.add_sequential_layer(LayerDesc(ValueHead, config, **init_kwargs), "") + + def get_loss_fn(self, config): + return create_loss(RLHFValueLoss, config, self._init_kwargs) + + @property + def head_out_meta(self): + # None means to use actual data info + return ( + paddle.static.InputSpec(shape=[None, None, 1], dtype=None), + paddle.static.InputSpec(shape=[None, 1], dtype=None), + ) diff --git a/examples/RLHF/models/pp_model_utils.py b/examples/RLHF/models/pp_model_utils.py new file mode 100644 index 000000000000..1444cdbdd2e2 --- /dev/null +++ b/examples/RLHF/models/pp_model_utils.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
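The module below patches `PipelineParallel._forward_step` through a generic `make_wrapper` helper that runs optional pre/post hooks around the original function; that is how `fwd_step_patch` collects per-micro-batch losses. A small usage sketch of the same helper on a plain function, where the import path is an assumption (working directory `examples/RLHF`):

```python
from models.pp_model_utils import make_wrapper  # path is an assumption, see above

collected = []

def record_output(func, output, *args, **kwargs):
    # post hook receives the wrapped function, its output and its arguments
    collected.append(output)

def double(x):
    return x * 2

double = make_wrapper(double, post_patch=record_output)
double(3)
assert collected == [6]
```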
+ +import importlib +import inspect + +import paddle + + +def fwd_step_patch(func, output, self, *args, **kwargs): + # training patch + if self.training and self.is_pipeline_last_stage(): + if getattr(self, "_step_losses", None): + self._step_losses.append(output.detach()) + else: + self._step_losses = [output.detach()] + + +def make_wrapper(func, pre_patch=None, post_patch=None): + def wrapper(*args, **kwargs): + if pre_patch is not None: + pre_patch(func, None, *args, **kwargs) + output = func(*args, **kwargs) + if post_patch is not None: + post_patch(func, output, *args, **kwargs) + return output + + return wrapper + + +funcs = [(paddle.distributed.fleet.model.PipelineParallel._forward_step, fwd_step_patch)] + +for func in funcs: + fun, patch = func + module = importlib.import_module(fun.__module__) + cls_name = fun.__qualname__[: -len(fun.__name__) - 1] + wrap_fun = make_wrapper(fun, post_patch=patch) + cls_obj = getattr(module, cls_name) + setattr(cls_obj, fun.__name__, wrap_fun) + + +@paddle.no_grad() +def pad_batches_inputs(inputs, padding_value=0, max_len=None, pad_len=None): + """Pad length for tensors shaped [bs, seq_len] to [bs, max(seq_lens)]""" + if pad_len is not None: + pad_len = [pad_len] * len(inputs) if isinstance(pad_len, int) else pad_len + elif max_len is None: + # max_len = max([x.shape[-1] for x in inputs if x is not None]) + max_len = max([x.shape[-1] if isinstance(x, paddle.Tensor) else 0 for x in inputs]) + pad_len = [max_len - x.shape[-1] if isinstance(x, paddle.Tensor) else 0 for x in inputs] + for i in range(len(inputs)): + x = inputs[i] + # if x is None or x.shape[-1] == max_len: + if not isinstance(x, paddle.Tensor) or x.shape[-1] == max_len: + continue + inputs[i] = paddle.concat([x, paddle.full([x.shape[0], pad_len[i]], padding_value, dtype=x.dtype)], -1) + return inputs + + +def get_expected_keys(inputs, keys): + ret = tuple([inputs.get(k, None) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + +def fwd_args_to_dict(fun): + def _impl(self, *args, **kwargs): + try: + return fun(self, *args, **kwargs) + except TypeError: + # otherwise, inputs is any valid format of non_pipe_model forward args, + # convert to dict, to support more args format in prediction_pipeline_step + # assume no arg is inspect.Parameter.VAR_KEYWORD + arg_dict = ( + inspect.signature(self._non_pipe_model_class.forward).bind(*((self,) + args), **kwargs).arguments + ) + arg_dict.pop("self") + return fun(self, arg_dict) + + return _impl diff --git a/examples/RLHF/models/ppo_model.py b/examples/RLHF/models/ppo_model.py new file mode 100644 index 000000000000..720009161022 --- /dev/null +++ b/examples/RLHF/models/ppo_model.py @@ -0,0 +1,117 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
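`fwd_args_to_dict` above uses `inspect.signature(...).bind` to normalize whatever positional/keyword arguments reach `_prepare_pipeline_inputs_func` into a dict keyed by the non-pipe model's forward parameter names. A self-contained illustration of that binding step (the toy `forward` signature is an assumption):

```python
import inspect

def forward(self, input_ids=None, attention_mask=None, position_ids=None):
    """Stand-in for _non_pipe_model_class.forward."""

class _Self:
    pass

args = (_Self(), [[1, 2, 3]])             # input_ids passed positionally
kwargs = {"attention_mask": [[1, 1, 1]]}  # attention_mask passed by keyword
arg_dict = inspect.signature(forward).bind(*args, **kwargs).arguments
arg_dict.pop("self")
print(dict(arg_dict))  # {'input_ids': [[1, 2, 3]], 'attention_mask': [[1, 1, 1]]}
```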
+ + +from paddlenlp.transformers import LlamaForCausalLM, PretrainedConfig + +from .ppo_model_utils import PolicyOutput, RLHFPPOMixedLoss, RLHFValueLoss, ValueOutput +from .score_model import LlamaModelForScore + + +# TODO(guosheng): create Mixin and make model classes using metaclass. +class LlamaPolicyModel(LlamaForCausalLM): + def __init__(self, config: PretrainedConfig, **kwargs): + super().__init__(config) + self.loss_fn = RLHFPPOMixedLoss(config, **kwargs) + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + log_probs=None, + advantages=None, + sequence_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = super().forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=None, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs[0] + loss = None + if labels is not None or advantages is not None: + loss = self.loss_fn(logits, (labels, input_ids, log_probs, advantages, sequence_mask)) + if not return_dict: + return (loss,) + outputs if loss is not None else outputs + + return PolicyOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class LlamaValueModel(LlamaModelForScore): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.loss_fn = RLHFValueLoss(config, **kwargs) + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + old_values=None, + returns=None, + sequence_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = super().forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + reward_values, rewards = outputs + loss = None + if returns is not None: + loss = self.loss_fn(reward_values, old_values, returns, sequence_mask) + if not return_dict: + return (loss,) + outputs if loss is not None else outputs + + return ValueOutput( + loss=loss, + value=reward_values, + reward=rewards, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/RLHF/models/ppo_model_utils.py b/examples/RLHF/models/ppo_model_utils.py new file mode 100644 index 000000000000..da8972cc5c6d --- /dev/null +++ b/examples/RLHF/models/ppo_model_utils.py @@ -0,0 +1,274 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for score models.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Optional, Tuple + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# use LlamaPretrainingCriterion as common PretrainingCriterion +from paddlenlp.transformers import LlamaPretrainingCriterion as PretrainingCriterion +from paddlenlp.transformers.model_outputs import ModelOutput + + +@dataclass +class PolicyOutput(ModelOutput): + loss: Optional[paddle.Tensor] = None + logits: paddle.Tensor = None + # logits_entropy: Optional[paddle.Tensor] = None + past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + cross_attentions: Optional[Tuple[paddle.Tensor]] = None + + +@dataclass +class ValueOutput(ModelOutput): + loss: Optional[paddle.Tensor] = None + value: paddle.Tensor = None + reward: paddle.Tensor = None + past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + cross_attentions: Optional[Tuple[paddle.Tensor]] = None + + +def merge_fwd_labels(loss_cls): + """ + PipelineParallel and trainer.criterion both use labels as tuple, thus wrap. + """ + ori_fwd = loss_cls.forward + + def loss_fwd(self, predict, labels): + return ori_fwd(self, predict, *labels) + + fwd_params = inspect.signature(ori_fwd).parameters + # forward(self, predict, label1, label2, ...) + loss_cls.label_names = list(fwd_params.keys())[2:] + loss_cls.label_default_values = {} + for label_name in loss_cls.label_names: + if fwd_params[label_name].default is not inspect.Parameter.empty: + loss_cls.label_default_values[label_name] = fwd_params[label_name].default + loss_cls.forward = loss_fwd + return loss_cls + + +def create_loss(loss_cls, config, extra_args, merge_labels=None): + """ + loss_cls(paddle.nn.Layer): loss class + config(PratrainedConfig): model config, to be consistent with loss defined + in transformers + extra_args(dict): create loss with more args not in config + merge_labels: use a wrapped loss_cls whose label args are merged into one arg, + this is useful to PipelineParallel and trainer.criterion since they only + support loss format corresponding to this format. + """ + # TODO(guosheng): merge_labels if loss_cls not + ori_fwd = loss_cls.forward + if merge_labels: + fwd_params = inspect.signature(ori_fwd).parameters + if len(fwd_params.keys()) > 3: # merge_fwd_labels has not done + loss_cls = merge_fwd_labels(loss_cls) + # forward(self, predict, label1, label2, ...) 
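`merge_fwd_labels` above rewrites a loss class so its multiple label arguments can be passed as a single tuple, the format PipelineParallel and `trainer.criterion` expect, while also recording `label_names` and their default values. A minimal sketch with a toy loss, relying only on `merge_fwd_labels` from this module:

```python
import paddle
import paddle.nn as nn

@merge_fwd_labels  # defined earlier in this module
class ToyLoss(nn.Layer):
    def __init__(self, config):
        super().__init__()

    def forward(self, predict, targets, mask=None):
        return (predict - targets).abs().sum()

loss_fn = ToyLoss(config=None)
print(ToyLoss.label_names)           # ['targets', 'mask']
print(ToyLoss.label_default_values)  # {'mask': None}
loss = loss_fn(paddle.ones([2]), (paddle.zeros([2]), None))  # labels as one tuple
```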
+ loss_arg_names = list(inspect.signature(loss_cls.__init__).parameters.keys())[2:] + if isinstance(extra_args, dict): + loss_kwargs = dict([(name, extra_args[name]) for name in loss_arg_names if name in extra_args]) + else: + # create from TrainingArguments + loss_kwargs = dict([(name, getattr(extra_args, name)) for name in loss_arg_names if hasattr(extra_args, name)]) + loss = loss_cls(config, **loss_kwargs) + return loss + + +@paddle.no_grad() +def make_position_ids(attention_mask, source=None): + if len(attention_mask.shape) == 4: # causal mask + position_ids_p1 = attention_mask.cast(paddle.int64).sum(-1) + position_ids = position_ids_p1 - 1 + position_ids = paddle.where(position_ids == -1, position_ids_p1, position_ids) + return position_ids[:, 0, :] + assert len(attention_mask.shape) == 2 # padding mask + attention_mask_bool = attention_mask + attention_mask = attention_mask.cast(paddle.int64) + position_ids = attention_mask.cumsum(-1) - 1 + # Make padding positions in source be 0, since reward model use position_ids + # plus with padding size (number of 0s) in source to calculate end offsets. + # It does not matter when source is left padding and target is right padding + # which is the output of non-FuseMT generation, while when using FuseMT whose + # output is right padding source and right padding target, we have to set + # padding positions in source be 0 to make compatible. + if source is not None: + src_len = position_ids[:, source.shape[-1] - 1].unsqueeze(-1) + position_ids = paddle.where( + paddle.logical_and(paddle.logical_not(attention_mask_bool), position_ids <= src_len), + attention_mask, + position_ids, + ) + return position_ids + position_ids = paddle.where(position_ids == -1, attention_mask, position_ids) + return position_ids + + +@paddle.no_grad() +def make_attention_mask(input_ids, pad_id, unk_id=None, past_key_values_length=0, causal_mask=True): + attention_mask = input_ids != pad_id + if unk_id is not None and pad_id != unk_id: + attention_mask = paddle.logical_and(attention_mask, input_ids != unk_id) + if not causal_mask: + return attention_mask + + batch_size, target_length = input_ids.shape # target_length: seq_len + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + # [bs, 1, tgt_len, tgt_len + past_len] + causal_mask = mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + attention_mask = attention_mask[:, None, None, :] + expanded_attn_mask = attention_mask & causal_mask + return expanded_attn_mask + + +def gather_log_probabilities(logits: paddle.Tensor, labels: paddle.Tensor) -> paddle.Tensor: + """Gather log probabilities of the given labels from the logits.""" + log_probs = F.log_softmax(logits, axis=-1) + log_probs_labels = paddle.take_along_axis(log_probs, axis=-1, indices=labels.unsqueeze(axis=-1)) + return log_probs_labels.squeeze(axis=-1) + + +class RLHFPPOLoss(nn.Layer): + def __init__(self, config, clip_range_ratio=0.2): + super().__init__() + self.clip_range_ratio = clip_range_ratio + self.config = config + + def actor_loss_fn( + self, log_probs: paddle.Tensor, old_log_probs: paddle.Tensor, advantages: paddle.Tensor, mask: paddle.Tensor + ) -> paddle.Tensor: + # policy gradient loss + ratio = paddle.exp(log_probs - old_log_probs) + pg_loss1 = -advantages * ratio + pg_loss2 = -advantages * 
paddle.clip( + ratio, + 1.0 - self.clip_range_ratio, + 1.0 + self.clip_range_ratio, + ) + return paddle.sum(paddle.maximum(pg_loss1, pg_loss2) * mask) / mask.sum() + + def forward(self, logits, input_ids, old_log_probs, reward_advantages, sequence_mask): + log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) + if log_probs.shape[1] == old_log_probs.shape[1]: + # labels (old_log_probs, reward_advantages, sequence_mask) has + # src+tgt-1 length, valid length is determined by sequence_mask + pass + elif log_probs.shape[1] < old_log_probs.shape[1]: + # labels (old_log_probs, reward_advantages, sequence_mask) has + # src+tgt length and the last one is a padding to be consistent + # with input_ids + assert log_probs.shape[1] == old_log_probs.shape[1] - 1 + log_probs = paddle.concat([log_probs, paddle.zeros([log_probs.shape[0], 1], dtype=log_probs.dtype)], -1) + else: + # labels (old_log_probs, reward_advantages, sequence_mask) has tgt length + log_probs = log_probs[:, -old_log_probs.shape[1] :] + actor_loss = self.actor_loss_fn( + log_probs, + old_log_probs, + reward_advantages, + sequence_mask, + ) + return actor_loss + + +@merge_fwd_labels +class RLHFPPOMixedLoss(nn.Layer): + """provide two losses, one for PPO loss, the other for SFT loss.""" + + def __init__(self, config, ptx_coeff=16, clip_range_ratio=0.2): + super(RLHFPPOMixedLoss, self).__init__() + self.ptx_coeff = ptx_coeff + self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio) + self.sft_criterion = PretrainingCriterion(config) + + def forward(self, logits, labels, input_ids, old_log_probs, reward_advantages, sequence_mask): + logits = logits if isinstance(logits, paddle.Tensor) else logits[0] + loss = None + # sft, pt loss + if labels is not None: + loss = self.ptx_coeff * self.sft_criterion(logits, labels) + # ppo loss + if reward_advantages is not None: + loss = self.ppo_criterion(logits, input_ids, old_log_probs, reward_advantages, sequence_mask) + + return loss + + +@merge_fwd_labels +class RLHFValueLoss(nn.Layer): + def __init__(self, config, clip_range_value=5.0): + super().__init__() + self.clip_range_value = clip_range_value + self.config = config + + def critic_loss_fn( + self, + values: paddle.Tensor, + old_values: paddle.Tensor, + returns: paddle.Tensor, + mask: paddle.Tensor, + ) -> paddle.Tensor: + """Compute critic loss.""" + # TODO(guosheng): use paddle.clip when its min/max can support more than + # 0D Tensor + values_clipped = paddle.minimum( + paddle.maximum(values, old_values - self.clip_range_value), old_values + self.clip_range_value + ) + vf_loss1 = paddle.square(values - returns) + vf_loss2 = paddle.square(values_clipped - returns) + return 0.5 * paddle.sum(paddle.maximum(vf_loss1, vf_loss2) * mask) / mask.sum() + + def forward(self, reward_values, old_reward_values, reward_returns, sequence_mask): + reward_values = reward_values if isinstance(reward_values, paddle.Tensor) else reward_values[0] + reward_values = reward_values.squeeze(axis=-1)[:, :-1] + if reward_values.shape[1] == old_reward_values.shape[1]: + # labels (old_reward_values, reward_returns, sequence_mask) has + # src+tgt-1 length, valid length is determined by sequence_mask + pass + elif reward_values.shape[1] < old_reward_values.shape[1]: + # labels (old_reward_values, reward_returns, sequence_mask) has + # src+tgt length and the last one is a padding to be consistent + # with input_ids + assert reward_values.shape[1] == old_reward_values.shape[1] - 1 + reward_values = paddle.concat( + [reward_values, 
paddle.zeros([reward_values.shape[0], 1], dtype=reward_values.dtype)], -1 + ) + else: + # labels (old_reward_values, reward_returns, sequence_mask) has + # tgt length + reward_values = reward_values[:, -old_reward_values.shape[1] :] + reward_critic_loss = self.critic_loss_fn( + reward_values, + old_reward_values, + reward_returns, + sequence_mask, + ) + + return reward_critic_loss diff --git a/examples/RLHF/models/score_model.py b/examples/RLHF/models/score_model.py index dd3741596a9e..aa0f50977945 100644 --- a/examples/RLHF/models/score_model.py +++ b/examples/RLHF/models/score_model.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any import paddle from paddle import nn +import paddlenlp from paddlenlp.transformers import ( LlamaConfig, LlamaModel, @@ -88,6 +91,7 @@ def forward( # pylint: disable=too-many-arguments return self.get_score( hidden_states, attention_mask=attention_mask, + position_ids=position_ids, return_dict=return_dict, ) @@ -131,3 +135,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] return mappings + + +paddlenlp.transformers.LlamaModelForScore = LlamaModelForScore diff --git a/examples/RLHF/models/score_model_utils.py b/examples/RLHF/models/score_model_utils.py index 5d14f7995731..5515d56fbc20 100644 --- a/examples/RLHF/models/score_model_utils.py +++ b/examples/RLHF/models/score_model_utils.py @@ -49,9 +49,10 @@ class AutoModelForScore(_BaseAutoModelClass): _score_module_name: str = "models.score_model" @classmethod - def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path): - with io.open(config_file_path, encoding="utf-8") as f: - config = json.load(f) + def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path, config=None): + if config is None: + with io.open(config_file_path, encoding="utf-8") as f: + config = json.load(f) # Get class name corresponds to this configuration if is_standard_config(config): @@ -167,25 +168,53 @@ def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: def get_score( self, hidden_state: paddle.Tensor, # size = (B, L, E) - attention_mask: paddle.Tensor, # size = (B, L) + attention_mask: paddle.Tensor | None = None, # size = (B, L) + position_ids: paddle.Tensor | None = None, # size = (B, L) return_dict: bool | None = None, ) -> ScoreModelOutput: """Forward pass of the score model.""" scores = self.score_head(hidden_state) # size = (B, L, D) - end_score = [] - for i in range(hidden_state.shape[0]): - end_index = attention_mask[i].nonzero()[-1].item() - end_score.append(scores[i, end_index]) # size = (D,) - end_score = paddle.stack(end_score, axis=0) # size = (B, D) - - if self.training: + if position_ids is not None: + first_pos = paddle.arange(hidden_state.shape[0]).unsqueeze(-1) + # Take left padding into account, which has 0s in left and max_len + # in right. 
+ left_pad_mask = position_ids == 0 + # position_ids = paddle.where( + # left_pad_mask, position_ids, position_ids + left_pad_mask.sum(-1, keepdim=True) - 1 + # ) + # the above limits right padding must not be 0s, the following suits + # to both left and right padding with 0s + left_pad_num = ( + paddle.where(left_pad_mask, position_ids.shape[-1] + 100, position_ids).argmin(axis=-1, keepdim=True) + - 1 + ) + position_ids = left_pad_num + position_ids + second_pos = paddle.max(position_ids, axis=-1, keepdim=True) + end_pos = paddle.stack([first_pos, second_pos], axis=-1).squeeze(1) + end_score = scores.gather_nd(end_pos) + else: + # attention_mask passed from pipeline pre-stage is shaped (bs, 1, seq_len, seq_len) + assert attention_mask is not None and len(attention_mask.shape) == 2 + end_score = [] + end_pos = [] + for i in range(hidden_state.shape[0]): + end_index = attention_mask[i].nonzero()[-1].item() + end_pos.append((i, end_index)) + end_score.append(scores[i, end_index]) # size = (D,) + end_score = paddle.stack(end_score, axis=0) # size = (B, D) + + if self.training and self.do_normalize: if dist.is_initialized(): - # TODO(guosheng): maybe only need nodes in data parallel group - # when support hybird dist parallel. - gathered_end_score_list = [paddle.zeros_like(end_score) for _ in range(dist.get_world_size())] - dist.all_gather(gathered_end_score_list, end_score) + gathered_end_score_list = [] + try: + # gather among data parallel group + hcg = dist.fleet.get_hybrid_communicate_group() + group = hcg.get_sharding_parallel_group() + dist.all_gather(gathered_end_score_list, end_score, group) + except: + dist.all_gather(gathered_end_score_list, end_score) gathered_end_score = paddle.concat(gathered_end_score_list, axis=0) self.normalizer.update(gathered_end_score) else: diff --git a/examples/RLHF/ppo_config.json b/examples/RLHF/ppo_config.json index d15331443608..7bc5f88e515f 100644 --- a/examples/RLHF/ppo_config.json +++ b/examples/RLHF/ppo_config.json @@ -4,8 +4,9 @@ "ptx_datasets": "alpaca", "actor_model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced", "reward_model_name_or_path": "PKU-Alignment/beaver-7b-v1.0-reward", - "output_dir": "./checkpoints/ppo", + "output_dir": "/root/paddlejob/workspace/guosheng/ckpts/ppo-reshard-sd38", "max_length": 512, + "top_p": 0.8, "temperature": 1.0, "num_return_sequences":1, "repetition_penalty": 1.0, @@ -34,7 +35,7 @@ "logging_steps": 1, "evaluation_strategy": "steps", "eval_steps": 100, - "save_strategy": "steps", + "save_strategy": "epoch", "save_steps": 100000, "bf16": true, "fp16_opt_level": "O2", @@ -43,9 +44,14 @@ "disable_tqdm": true, "save_total_limit": 1, "sharding_parallel_degree": 4, - "sharding": "stage3", - "comment-PKU_Beaver-max_grad_norm": 1.0, + "sharding": "stage1", + "tensor_parallel_degree": 2, + "pipeline_parallel_degree": 1, + "pipeline_parallel_config": "disable_p2p_cache_shape", "max_grad_norm": 1.0, "adam_beta1": 0.9, - "adam_beta2": 0.95 + "adam_beta2": 0.95, + "dataloader_drop_last": false, + "eval_mode": "", + "offload_level": "freeze_model" } diff --git a/examples/RLHF/ppo_main.py b/examples/RLHF/ppo_main.py index f517288a8679..d52b30b95f90 100644 --- a/examples/RLHF/ppo_main.py +++ b/examples/RLHF/ppo_main.py @@ -12,15 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
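For reference, the end-token selection that `get_score` performs above when `position_ids` are provided can be checked in isolation: the zeros in `position_ids` give the amount of left padding, the maximum shifted position is each sequence's last valid index, and `gather_nd` picks that token's score. The toy tensors below are a minimal sketch, illustrative only.

```python
import paddle

batch, seq_len, dim = 2, 6, 1
scores = paddle.arange(batch * seq_len, dtype="float32").reshape([batch, seq_len, dim])

# row 0: 2 left-padding tokens then 4 valid tokens; row 1: 4 valid tokens then 2 right-padding tokens
position_ids = paddle.to_tensor([[0, 0, 0, 1, 2, 3],
                                 [0, 1, 2, 3, 3, 3]])

first_pos = paddle.arange(batch).unsqueeze(-1)  # batch indices
left_pad_mask = position_ids == 0
# index of the first non-zero position minus one == number of left-padding tokens
left_pad_num = (
    paddle.where(left_pad_mask, position_ids.shape[-1] + 100, position_ids).argmin(axis=-1, keepdim=True) - 1
)
second_pos = paddle.max(left_pad_num + position_ids, axis=-1, keepdim=True)  # last valid token index
end_pos = paddle.stack([first_pos, second_pos], axis=-1).squeeze(1)  # (batch, 2) coordinates
end_score = scores.gather_nd(end_pos)  # (batch, dim): here scores[0, 5] and scores[1, 3]
print(end_score.tolist())
```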
+import copy import os import sys +import types + +# os.environ["http_proxy"] = "http://10.162.37.16:8128" +# os.environ["https_proxy"] = "http://10.162.37.16:8128" +# os.environ["no_proxy"] = "localhost,bcebos.com" +# launch would unset http_proxy +# export https_proxy=http://172.19.57.45:3128 + +# os.environ["http_proxy"] = "http://172.19.56.199:3128" +# os.environ["https_proxy"] = "http://172.19.56.199:3128" + +# os.environ["http_proxy"] = "http://172.19.57.45:3128" +# os.environ["https_proxy"] = "http://172.19.57.45:3128" + +os.environ["http_proxy"] = "http://10.162.37.16:8128" +os.environ["https_proxy"] = "http://10.162.37.16:8128" +os.environ["no_proxy"] = "localhost,bcebos.com" + +# os.environ["http_proxy"] = "agent.baidu.com:8118" +# os.environ["https_proxy"] = "agent.baidu.com:8118" + from dataclasses import dataclass, field +from functools import partial from typing import Any, Dict, Tuple import paddle from data import PromptOnlyDataset, SupervisedDataset, parse_dataset from models import AutoModelForScore -from ppo_trainer import PPOTrainer +from models.score_model import LlamaModelForScore # noqa +from ppo_trainer import PPOTrainer, cleanup_tensor_space, offload_tensor_to_cpu from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint from paddlenlp.transformers import ( @@ -31,6 +55,7 @@ ) from paddlenlp.utils.log import logger + @dataclass class TrainingArguments(TrainingArguments): kl_coeff: float = field( @@ -59,8 +84,8 @@ class TrainingArguments(TrainingArguments): default=0.0, metadata={"help": "The coefficient for the ptx loss."}, ) - update_iters: float = field( - default=0.0, + update_iters: int = field( + default=1, metadata={"help": "The number of repeated updates on a generated batch."}, ) critic_learning_rate: float = field( @@ -91,12 +116,12 @@ class TrainingArguments(TrainingArguments): default=1.0, metadata={"help": "The value used to module the next token probabilities."}, ) - top_k: int = field( - default=1, - metadata={"help": "top_k"}, - ) + # top_k: int = field( + # default=1, + # metadata={"help": "top_k"}, + # ) top_p: float = field( - default=1.0, + default=0.8, metadata={ "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`top_p` or higher are kept for generation." }, @@ -113,6 +138,22 @@ class TrainingArguments(TrainingArguments): default=16, metadata={"help": "Batch size (per device) for the training dataloader."}, ) + eval_mode: str = field( + default=None, + metadata={ + "help": "eval mode for actor model and reward_critic_model, optional for: None, single, tensor_parallel." 
+ }, + ) + + offload_level: str = field( + default="", + metadata={"help": "Offload model, optional for: eval, reward, optimizer, train_model"}, + ) + use_fusemt: bool = field( + default=True, + metadata={"help": "use inference model to speedup in rollout generation"}, + ) + # save_generation_output: bool = field( # default=False, # metadata={"help": "Whether to save generated text to file when eval"}, @@ -184,6 +225,10 @@ def main(): model_args, data_args, training_args = parser.parse_args_into_dataclasses() training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") + if training_args.eval_mode is not None and len(training_args.eval_mode) == 0: + training_args.eval_mode = None + if training_args.eval_mode is None and training_args.offload_level is not None: + training_args.offload_level = training_args.offload_level.replace("eval", "") # Setup GPU & distributed training paddle.set_device(training_args.device) @@ -219,60 +264,206 @@ def main(): dtype = "float32" training_args.max_length = data_args.max_length + model_class_lm, model_class_score = AutoModelForCausalLM, AutoModelForScore if training_args.pipeline_parallel_degree > 1: - raise ValueError("Not support pipeline parallel mode.") + from models.model_pp import LlamaPolicyPipe, LlamaValuePipe + + model_class_lm = LlamaPolicyPipe + model_class_score = LlamaValuePipe + extra_args = { + "ptx_coeff": training_args.ptx_coeff, + "clip_range_ratio": training_args.clip_range_ratio, + } else: - # actor model - model_config = AutoConfig.from_pretrained( - model_args.actor_model_name_or_path, - tensor_parallel_output=False, - tensor_parallel_degree=training_args.tensor_parallel_degree, - tensor_parallel_rank=training_args.tensor_parallel_rank, - dtype=dtype, - ) - if hasattr(model_config, "use_flash_attention"): - model_config.use_flash_attention = model_args.use_flash_attention - actor_model = AutoModelForCausalLM.from_pretrained( + # non-pipe modelForCausalLM does not accept extra_args and use other ways + # (StepTrainer.create_criterion) to set hyper-parameters + extra_args = {} + + # actor model + model_config = AutoConfig.from_pretrained( + model_args.actor_model_name_or_path, + tensor_parallel_output=False, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + dtype=dtype, + ) + if hasattr(model_config, "use_flash_attention"): + model_config.use_flash_attention = model_args.use_flash_attention + + # model_config.num_hidden_layers = 2 + + actor_model = model_class_lm.from_pretrained( + model_args.actor_model_name_or_path, + config=model_config, + **extra_args, + # ptx_coeff=training_args.ptx_coeff, + # clip_range_ratio=training_args.clip_range_ratio, + ) + if training_args.eval_mode is not None: + config = copy.deepcopy(actor_model.config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + actor_eval_model = AutoModelForCausalLM.from_config(config) + # TODO(guosheng): AutoModel (in `_get_model_class_from_config`) pop out + # architecture which is necessary for infer predictor currently + config.architectures = actor_model.config.architectures + # actor_eval_model = AutoModelForCausalLM.from_pretrained(model_args.actor_model_name_or_path, config=config) + else: + actor_eval_model = None + + # todo reference model + if training_args.eval_mode is not None: + config = copy.deepcopy(model_config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + 
config.tensor_parallel_rank = 0 + actor_reference_model = AutoModelForCausalLM.from_pretrained( model_args.actor_model_name_or_path, - config=model_config, + config=config, ) - # reference model - actor_reference_model = AutoModelForCausalLM.from_pretrained( + else: + actor_reference_model = model_class_lm.from_pretrained( model_args.actor_model_name_or_path, config=model_config, ) - actor_tokenizer = AutoTokenizer.from_pretrained( - model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" - ) - # reward model - model_config = AutoConfig.from_pretrained( + actor_tokenizer = AutoTokenizer.from_pretrained( + model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" + ) + + # reward model + model_config = AutoConfig.from_pretrained( + model_args.reward_model_name_or_path, + tensor_parallel_output=False, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + dtype=dtype, + ) + if hasattr(model_config, "use_flash_attention"): + model_config.use_flash_attention = model_args.use_flash_attention + # model_config.num_hidden_layers = 2 + # todo + if training_args.eval_mode is not None: + config = copy.deepcopy(model_config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + reward_model = AutoModelForScore.from_pretrained( model_args.reward_model_name_or_path, - tensor_parallel_output=False, - tensor_parallel_degree=training_args.tensor_parallel_degree, - tensor_parallel_rank=training_args.tensor_parallel_rank, - dtype=dtype, + config=config, + score_type="reward", + do_normalize=training_args.normalize_reward, ) - if hasattr(model_config, "use_flash_attention"): - model_config.use_flash_attention = model_args.use_flash_attention - reward_model = AutoModelForScore.from_pretrained( + else: + reward_model = model_class_score.from_pretrained( model_args.reward_model_name_or_path, config=model_config, score_type="reward", do_normalize=training_args.normalize_reward, ) - reward_tokenizer = AutoTokenizer.from_pretrained( - model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right" - ) - # critic model - if model_args.reward_critic_model_name_or_path is None: - model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path - reward_critic_model = AutoModelForScore.from_pretrained( - model_args.reward_critic_model_name_or_path, config=model_config, score_type="critic", do_normalize=False - ) - reward_critic_tokenizer = AutoTokenizer.from_pretrained( - model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" - ) + reward_tokenizer = AutoTokenizer.from_pretrained( + model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right" + ) + # critic model + if model_args.reward_critic_model_name_or_path is None: + model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path + reward_critic_model = model_class_score.from_pretrained( + model_args.reward_critic_model_name_or_path, + config=model_config, + score_type="critic", + do_normalize=False, + clip_range_value=training_args.clip_range_value, + ) + reward_critic_tokenizer = AutoTokenizer.from_pretrained( + model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" + ) + if training_args.eval_mode is not None: + config = 
copy.deepcopy(reward_critic_model.config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + reward_critic_eval_model = AutoModelForScore.from_config(config) + # reward_critic_eval_model = AutoModelForScore.from_pretrained( + # model_args.reward_critic_model_name_or_path,config=model_config + # ) + else: + reward_critic_eval_model = None + + # # actor model + # model_config = AutoConfig.from_pretrained( + # model_args.actor_model_name_or_path, + # tensor_parallel_output=False, + # tensor_parallel_degree=training_args.tensor_parallel_degree, + # tensor_parallel_rank=training_args.tensor_parallel_rank, + # dtype=dtype, + # ) + # model_config.num_hidden_layers = 2 + # if hasattr(model_config, "use_flash_attention"): + # model_config.use_flash_attention = model_args.use_flash_attention + # actor_model = AutoModelForCausalLM.from_pretrained( + # model_args.actor_model_name_or_path, + # config=model_config, + # ) + # + # if training_args.eval_mode is not None: + # config = copy.deepcopy(actor_model.config) + # if training_args.eval_mode == "single": + # config.tensor_parallel_degree = -1 + # config.tensor_parallel_rank = 0 + # actor_eval_model = AutoModelForCausalLM.from_config(config) + # else: + # actor_eval_model = None + # + # # reference model + # actor_reference_model = AutoModelForCausalLM.from_pretrained( + # model_args.actor_model_name_or_path, + # config=model_config, + # ) + # actor_tokenizer = AutoTokenizer.from_pretrained( + # model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" + # ) + # + # # reward model + # model_config = AutoConfig.from_pretrained( + # model_args.reward_model_name_or_path, + # tensor_parallel_output=False, + # tensor_parallel_degree=training_args.tensor_parallel_degree, + # tensor_parallel_rank=training_args.tensor_parallel_rank, + # dtype=dtype, + # ) + # model_config.num_hidden_layers = 2 + # if hasattr(model_config, "use_flash_attention"): + # model_config.use_flash_attention = model_args.use_flash_attention + # reward_model = AutoModelForScore.from_pretrained( + # model_args.reward_model_name_or_path, + # config=model_config, + # score_type="reward", + # do_normalize=training_args.normalize_reward, + # ) + # reward_tokenizer = AutoTokenizer.from_pretrained( + # model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right" + # ) + # + # # critic model + # if model_args.reward_critic_model_name_or_path is None: + # model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path + # reward_critic_model = AutoModelForScore.from_pretrained( + # model_args.reward_critic_model_name_or_path, config=model_config, score_type="critic", do_normalize=False + # ) + # reward_critic_tokenizer = AutoTokenizer.from_pretrained( + # model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" + # ) + # + # if training_args.eval_mode is not None: + # config = copy.deepcopy(reward_critic_model.config) + # if training_args.eval_mode == "single": + # config.tensor_parallel_degree = -1 + # config.tensor_parallel_rank = 0 + # reward_critic_eval_model = AutoModelForScore.from_config(config) + # else: + # reward_critic_eval_model = None + for tokenizer in [actor_tokenizer, reward_tokenizer, reward_critic_tokenizer]: if isinstance(tokenizer, LlamaTokenizer) and tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -290,9 +481,37 @@ def main(): if 
data_args.ptx_datasets is not None else None ) + if ptx_ds is not None: + # PretrainingCriterion requires shifted inputs and labels + ptx_ds.get_collator = types.MethodType(partial(ptx_ds.get_collator.__func__, shift=True), ptx_ds) + + # offload + # cleanup actor_eval_model, reward_critic_eval_model + # offload actor_reference_model reward_model + + if training_args.offload_level is not None: + if "eval" in training_args.offload_level: + cleanup_tensor_space(actor_eval_model.state_dict()) + cleanup_tensor_space(reward_critic_eval_model.state_dict()) + if "reward" in training_args.offload_level: + # if pp mode, should lazy offload + offload_tensor_to_cpu(actor_reference_model.state_dict()) + offload_tensor_to_cpu(reward_model.state_dict()) trainer = PPOTrainer( - model=(actor_model, actor_reference_model, reward_model, reward_critic_model), + # (policy_model, reference_model, reward_model, value_model) + # policy_model, sft_model, reward_model, value_model + # (policy_model, reference_model, reward_model, value_model, + # (policy_model, reference_model, reward_model, value_model, policy_eval_model, value_eval_model + # (actor_model, actor_reference_model, reward_model, reward_critic_model, actor_eval_model, reward_critic_eval_model + model=( + actor_model, + actor_reference_model, + reward_model, + reward_critic_model, + actor_eval_model, + reward_critic_eval_model, + ), args=training_args, train_dataset=train_ds, eval_dataset=dev_ds, diff --git a/examples/RLHF/ppo_trainer.py b/examples/RLHF/ppo_trainer.py index 84577d6d2c01..c2c72d6c5cd1 100644 --- a/examples/RLHF/ppo_trainer.py +++ b/examples/RLHF/ppo_trainer.py @@ -16,471 +16,73 @@ import itertools import math import os +import sys import time -from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import paddle +import paddle.distributed as dist import paddle.nn as nn -import paddle.nn.functional as F -import tqdm -from data import DummyDataset, PromptOnlyBatch +from comm_utils import ( # noqa + cleanup_tensor_space, + create_data_trans_group, + data_group_merge, + data_group_split, + offload_tensor_to_cpu, + reload_tensor_to_gpu, +) +from infer_utils import InferEvalModel, infer_guard +from models.ppo_model_utils import ( + RLHFPPOMixedLoss, + RLHFValueLoss, + create_loss, + gather_log_probabilities, + make_attention_mask, + make_position_ids, +) +from paddle.distributed import fleet from paddle.io import DataLoader, Dataset, DistributedBatchSampler from paddle.utils import map_structure from rich.console import Console from rich.table import Table +from trainer_utils import ( + MuteDefaultFlowCallback, + PipeEvalModel, + batch_retokenize, + guard_set_args, + is_same_tokenizer, +) from paddlenlp.data import DataCollator from paddlenlp.generation import GenerationConfig from paddlenlp.trainer.trainer import ( - TRAINER_STATE_NAME, EvalLoopOutput, EvalPrediction, - HybridParallelOptimizer, - NlpDistributedBatchSampler, ShardingOption, Trainer, TrainerCallback, - TrainerControl, - TrainerState, TrainingArguments, - _obtain_optimizer_parameters_list, - distributed_file, - distributed_isfile, - fused_allreduce_gradients, logger, - reshard_util, speed_metrics, - split_inputs_sequence_dim, -) -from paddlenlp.transformers import BatchEncoding, PretrainedModel, PretrainedTokenizer -from paddlenlp.transformers.tokenizer_utils_base import ( - PaddingStrategy, - TruncationStrategy, ) +from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer -def 
batch_retokenize( - input_ids: paddle.Tensor, - src_tokenizer: PretrainedTokenizer, - dest_tokenizer: PretrainedTokenizer, - *, - padding: bool | str | PaddingStrategy = PaddingStrategy.LONGEST, - truncation: bool | str | TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - skip_special_tokens: bool = True, -) -> BatchEncoding: - """Re-tokenize a batch of input ids from one tokenizer to another.""" - output = dest_tokenizer( - [ - text + dest_tokenizer.eos_token - for text in src_tokenizer.batch_decode( - input_ids, - skip_special_tokens=skip_special_tokens, - ) - ], - padding=padding, - truncation=truncation, - return_tensors="pd", - ) - return output - - -def gather_log_probabilities(logits: paddle.Tensor, labels: paddle.Tensor) -> paddle.Tensor: - """Gather log probabilities of the given labels from the logits.""" - log_probs = F.log_softmax(logits, axis=-1) - log_probs_labels = paddle.take_along_axis(log_probs, axis=-1, indices=labels.unsqueeze(axis=-1)) - return log_probs_labels.squeeze(axis=-1) - - -def init_train_model_opt( - self: Trainer, max_steps: int, resume_from_checkpoint: bool = False, clear_master_weight: bool = False -) -> PretrainedModel: - # Copy of model/optimizer init and resuming related code in `Trainer.train`. - # NOTE: this `_load_from_checkpoint` is indeed to load model states in the - # following elif-else branches, though they are apart away in `Trainer.train`. - if not self.args.should_load_sharding_stage1_model: - self._load_from_checkpoint(resume_from_checkpoint) - - # delay_optimizer_creation = ( - # self.sharding is not None - # and ShardingOption.SHARD_OP in self.args.sharding - # ) - delay_optimizer_creation = False - - if not delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - - if self.args.should_load_sharding_stage1_model: - model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) - elif self.args.should_save_sharding_stage1_model: - # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. - # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. 
- model = self._wrap_model(self.model_wrapped) - if self.sharding_io is not None: - assert delay_optimizer_creation is False, "delay_optimizer_creation should be False" - # the self.optimizer should be wrapped and it is done in _wrap_model - self.sharding_io.set_optimizer(self.optimizer) - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self._load_optimizer_and_scheduler(resume_from_checkpoint) - else: - model = self._wrap_model(self.model_wrapped) - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self._load_optimizer_and_scheduler(resume_from_checkpoint) - - if ShardingOption.FULL_SHARD in self.args.sharding and clear_master_weight: - # for inference model to use Trainer sharding stage3, clear master_weight - # which is created in GroupShardedStage3.__init__ - self.optimizer._master_weights = None - - if self.args.device == "npu" and self.args.flatten_param_grads: - from .plugins.npu_plugin import npu_accelerate_plugin - - npu_accelerate_plugin(self.optimizer) - - return model - - -def init_train_state( - self: Trainer, - resume_from_checkpoint: bool, - train_dataloader: DataLoader, - max_steps: int, - num_train_epochs: int, - num_update_steps_per_epoch: int, -): - args = self.args - - self.state = TrainerState() - self.state.epoch = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - steps_trained_progress_bar = None - - # Check if continuing training from a checkpoint - if resume_from_checkpoint is not None and distributed_isfile( - os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) - ): - self.state = TrainerState.load_from_json( - distributed_file(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) - ) - epochs_trained = self.state.global_step // num_update_steps_per_epoch - if not args.ignore_data_skip: - steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) - steps_trained_in_current_epoch *= args.gradient_accumulation_steps - else: - steps_trained_in_current_epoch = 0 - - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info(f" Continuing training from global step {self.state.global_step}") - if not args.ignore_data_skip: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " - "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " - "flag to your launch command, but you will resume the training on data already seen by your model." 
- ) - if self.is_local_process_zero() and not args.disable_tqdm: - steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) - steps_trained_progress_bar.set_description("Skipping the first batches") - if not args.ignore_data_skip: - if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( - train_dataloader.batch_sampler, NlpDistributedBatchSampler - ): - consumed_samples = ( - self.state.global_step - * args.train_batch_size - * args.gradient_accumulation_steps - * args.dataset_world_size - ) - train_dataloader.batch_sampler.set_epoch(consumed_samples=consumed_samples) - logger.info(f"Set DistributedBatchSampler consumed_samples to {consumed_samples}") - - self.state.max_steps = int(max_steps) - self.state.num_train_epochs = num_train_epochs - self.state.is_local_process_zero = self.is_local_process_zero() - self.state.is_world_process_zero = self.is_world_process_zero() - - return epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar - - -def init_train_log( - self: Trainer, - num_examples: int, - num_train_epochs: int, - total_train_batch_size: int, - max_steps: int, - num_train_samples: int, - model: PretrainedModel, -): - args = self.args - - logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples:,}") - logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps:,}") - logger.info(f" Total num train samples = {num_train_samples:,}") - # per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) - # TODO: Temporary fix since Tensor.numel() not supported in distributed mode - per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient) - logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") - if self.args.use_hybrid_parallel: - # todo fix for pipeline_parallel_degree - parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) - if parts_num > 1: - all_reduce_dtype = "int64" - if paddle.get_device().split(":")[0] in ["npu", "xpu"]: - # TODO(duanyanhui): fix when NPU all_reduce supports int64 - all_reduce_dtype = "float32" - trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) - paddle.distributed.all_reduce(trainable_numel_tensor) - trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size - # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited - # so, the trainable numel is a little bigger than real. - logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") - - -def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs): +class StepTrainer(Trainer): """ - Just a copy of single training step complete code in Trainer.train while loop - which including forward+backward+step, while wraps the inputs and outputs to - make the complicated copied code no need to change. Maybe a better way is to - add fine-grained methods including these steps to Trainer which is similar to - DeepSpeed engine. + Features of StepTrainer: + 1. 
Trainer enhanced with step-level training combining with patches of + Trianer. We can use this to do training whose step is composed of multi + models via multiple instances of StepTrainer, such as PPO. + 2. Additionally, using a mixed loss and get the separated loss metrics is + supported, which is helpful to PipelienParallel with a mixed loss. + 3. EMA is supported. """ - # TODO(guosheng): step, steps_trained_in_current_epoch and steps_trained_progress_bar - # should use reference since they would be overwrite. - # for state update - epoch = kwargs.get("epoch", 0) - step = kwargs.get("step", 0) - steps_in_epoch = kwargs.get("steps_in_epoch", 0) - step_control = kwargs.get("step_control", 0) - # for step and progress update when resuming data - train_dataloader = kwargs.get("train_dataloader", None) - resume_from_checkpoint = kwargs.get("resume_from_checkpoint", None) - steps_trained_in_current_epoch = kwargs.get("steps_trained_in_current_epoch", 0) - steps_trained_progress_bar = kwargs.get("steps_trained_progress_bar", None) - # for eval output ignore to gather - ignore_keys_for_eval = kwargs.get("ignore_keys_for_eval", None) - tr_loss = kwargs.get("tr_loss", 0.0) - model = kwargs.get("model", self.model_wrapped) - - args = self.args - - if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: - inputs = split_inputs_sequence_dim(inputs) - self.timers and self.timers("read-data").stop() - os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) - self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) - - # Skip past any already trained steps if resuming training - # for paddlenlp.utils.batch_sampler.DistributedBatchSampler - # We use consumed_samples to reset the status - if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( - train_dataloader.batch_sampler, NlpDistributedBatchSampler - ): - if step == 0: - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(steps_trained_in_current_epoch) - steps_trained_progress_bar.close() - steps_trained_progress_bar = None - self._load_rng_state(resume_from_checkpoint) - step += steps_trained_in_current_epoch - elif steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(1) - if steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) - # continue - final_local_vars = locals() - for k in kwargs.keys(): - if k in final_local_vars: - kwargs[k] = final_local_vars[k] - return kwargs - elif steps_trained_progress_bar is not None: - steps_trained_progress_bar.close() - steps_trained_progress_bar = None - - if step_control % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - self.timers and self.timers("forward-backward").start() - - dp_enabled = self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1 - forbidden_no_sync = False - # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API - # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync - if self.args.use_hybrid_parallel: - forbidden_no_sync = True - - availiable_no_sync = dp_enabled and not forbidden_no_sync - - is_no_sync = ( - ((step_control + 1) % args.gradient_accumulation_steps != 0) - and availiable_no_sync - and args._no_sync_in_gradient_accumulation - ) or (args.recompute and availiable_no_sync) - # sharding - # stage1. 
the same as ddp - # stage2. manualy collect gradient on dp group - - dp_master_grad = self.args.world_size > 1 and self.args.amp_master_grad and not self.args.use_hybrid_parallel - if dp_master_grad: - is_no_sync = True - - if is_no_sync: - # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. - with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) - else: - tr_loss_step = self.training_step(model, inputs) - - tr_loss += tr_loss_step - - if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( - # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch - ): - if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): - tr_loss /= self.args.gradient_accumulation_steps - - self.timers and self.timers("forward-backward").stop() - # Maunally collect gradients - # Case 1: Use recompute and dp - # Case 2: Hack dp with master_grad - # Case 3: Pipeline or sharding overlap - # local_rank != -1 don't means dp in networks. - self.timers and self.timers("all-reduce").start() - - # Case 1: Use recompute and dp / sharding stage1, - # manualy collect gradient for dp. - if args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(model.parameters()), None) - - # Case 2: hack dp with master_grad - if dp_master_grad and not (args.recompute and availiable_no_sync): - fused_allreduce_gradients(list(model.parameters()), None) - - # Pipeline parallel mode, handle gradient reduce here to overlap - pipeline_parallel_config = ( - set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set() - ) - enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config - enable_release_grads = "enable_release_grads" in pipeline_parallel_config - - # Case 3: Pipeline parallel mode, overlap with dp - if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: - parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) - - if not enable_dp_comm_overlap: - if self.optimizer._sharding_enable: - assert reshard_util.is_sharding_opt(self.optimizer) - self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) - - if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): - fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) - - self.timers and self.timers("all-reduce").stop() - self.timers and self.timers("optimizer-step").start() - - if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss(): - for p in model._layers.parameters(): - with paddle.no_grad(): - if hasattr(p, "main_grad") and p.main_grad is not None: - assert p.grad is None - p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) - elif p.grad is not None: - p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) - - # Optimizer step - self.callback_handler.on_optimizer_begin( - args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None - ) - optimizer_was_run = True - if self.do_grad_scaling: - scale_before = paddle.assign(self.scaler._scale) - self.scaler.step(self.optimizer) - self.scaler.update() - scale_after = self.scaler._scale - # Compatible with paddlepaddle 2.6.0 using typo word. 
- if hasattr(self.scaler, "_cache_founf_inf"): - optimizer_was_run = not self.scaler._cache_founf_inf - else: - optimizer_was_run = not self.scaler._cache_found_inf - if not optimizer_was_run: - scale_before_value = scale_before.cpu().numpy() - scale_after_value = scale_after.cpu().numpy() - logger.warning( - f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" - ) - elif isinstance(self.optimizer, HybridParallelOptimizer): - self.optimizer._step(parameters_list) - else: - self.optimizer.step() - - self.timers and self.timers("optimizer-step").stop() - - if optimizer_was_run: - self.lr_scheduler.step() - - if enable_release_grads and args.pipeline_parallel_degree > 1: - self.optimizer.clear_grad(set_to_zero=False) - for _, buffers in model._chunk_2_comm_buffers.items(): - for buffer in buffers: - buffer._clear_grad_storage() - else: - self.optimizer.clear_grad() - - self.callback_handler.on_optimizer_end( - args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None - ) + # used to create criterion for trainer, please refer to `create_criterion` + # for details. + loss_cls: type - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) - self._print_timer() - step_control = 0 - else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - step_control += 1 - - if self.control.should_epoch_stop or self.control.should_training_stop: - # break - final_local_vars = locals() - for k in kwargs.keys(): - if k in final_local_vars: - kwargs[k] = final_local_vars[k] - return kwargs - self.timers and self.timers("read-data").start() - - final_local_vars = locals() - for k in kwargs.keys(): - if k in final_local_vars: - kwargs[k] = final_local_vars[k] - return kwargs - - -Trainer.init_train_model_opt = init_train_model_opt -Trainer.init_train_log = init_train_log -Trainer.init_train_state = init_train_state -Trainer.full_training_step = full_training_step - - -class PolicyTrainer(Trainer): def __init__( self, model: Union[PretrainedModel, nn.Layer] = None, @@ -495,7 +97,6 @@ def __init__( optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, ): - super().__init__( model, criterion, @@ -509,208 +110,547 @@ def __init__( optimizers, preprocess_logits_for_metrics, ) + # criterion is only used for non-PipelineParallel models. criterion is + # included in model for PipelineParallel. + if getattr(self, "loss_cls", None) and self.criterion is None: + self.criterion = self.create_criterion() + + self.use_fusemt = getattr(args, "use_fusemt", False) + # ablout 4s slower than infer generation without ema + self.use_ema = getattr(args, "use_ema", False) + self.shard_ema = getattr(args, "shard_ema", False) + self.offload_ema = getattr(args, "offload_ema", True) + self.ema_beta = getattr(args, "ema_beta", 0.992) + + def create_criterion(self): + """ + create loss using `loss_cls` for trainer. It would use a wrapped loss_cls + whose label arguments are merged into one argument, this is useful to + PipelineParallel and trainer.criterion which limit loss format. 
+ """ + criterion = create_loss(self.loss_cls, self.model.config, self.args, merge_labels=True) + return criterion - def actor_loss_fn( - self, - log_probs: paddle.Tensor, - old_log_probs: paddle.Tensor, - advantages: paddle.Tensor, - mask: paddle.Tensor, - ) -> paddle.Tensor: - # policy gradient loss - ratio = paddle.exp(log_probs - old_log_probs) - pg_loss1 = -advantages * ratio - pg_loss2 = -advantages * paddle.clip( - ratio, - 1.0 - self.clip_range_ratio, - 1.0 + self.clip_range_ratio, - ) - return paddle.sum(paddle.maximum(pg_loss1, pg_loss2) * mask) / mask.sum() + def loss_identifier(self, inputs: Dict) -> str: + """ + Moreover, a model/StepTrainer instance may use a mixed loss which uses a + different loss for different step and inputs, while we often want to get + the separated loss metric. We use a callable discriminator using inputs + (dict) as arguments and returning corresponding loss name to identify + current loss. NOTE: please make the loss name ends with "_loss". `tr_loss` + is the default loss name used in trainer.train. + """ + return "tr_loss" - def compute_loss(self, model, inputs, return_outputs=False): + def set_eval_model(self, model): """ - How the loss is computed by Trainer. By default, all models return the loss in the first element. - Subclass and override for custom behavior. + To avoid eval/generation with PipelineParallel when training with PP, we + allow to use an extra eval model to do eval/generation, which would need + to reshard parameters and dispatch data according to model's distributed + topo. Currently, the eval model should cancel PP setting and keep the same + TP setting with training. """ - labels = inputs.get("labels", None) - if labels is not None: - labels = inputs.get("labels", None) - outputs = model(**inputs) - ptx_loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - ptx_loss = self.ptx_coeff * ptx_loss - return ptx_loss - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - old_log_probs = inputs["old_log_probs"] - reward_advantages = inputs["reward_advantages"] - sequence_mask = inputs["sequence_mask"] - start = inputs["start"] - # NOTE: TensorParallel model requires non-Tensor inputs to be lists, thus - # do not use these inputs currently. - # use_cache = inputs["use_cache"] - # return_dict = inputs["return_dict"] - outputs = model( - input_ids=input_ids, - attention_mask=attention_mask, # use_cache=use_cache, return_dict=return_dict + if model is None: + logger.warning("use None to set eval model for trainer and it would be ignored") + return + else: + self._inner_eval_model = model + # bind a new comm group for eval model data dispatch + # param dispatch is binded in `InferEvalModel.enable` + hcg = fleet.get_hybrid_communicate_group() + sd_group = hcg.get_sharding_parallel_group() + dp_group = hcg.get_data_parallel_group() + global_rank = dist.get_rank() + eval_tp_size = max(model.config.tensor_parallel_degree, 1) + eval_tp_rank = max(model.config.tensor_parallel_rank, 0) + old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1)) + group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank + self._data_trans_group = create_data_trans_group(global_rank, group_nums) + # just for compatiable with old code + self._policy_model_eval_group = self._data_trans_group + + def get_model(self, train=False): + """ + model visitor wrapps PipelineParalle and Inference model to do evaulation + and generation. 
+        """
+        if train:
+            return self.model_wrapped
+        model = getattr(self, "_eval_model", None)
+        if model is not None:
+            return model
+        inner_eval_model = getattr(self, "_inner_eval_model", None)
+        if (self.args.pipeline_parallel_degree > 1 and inner_eval_model is None) or isinstance(
+            inner_eval_model, fleet.model.PipelineParallel
+        ):
+            # Only accept wrapped model for pipeline_parallel mode
+            model = PipeEvalModel(self)
+            self._eval_model = model
+        else:
+            model = InferEvalModel(self)
+            self._eval_model = model
+        return model
+
+    def get_train_step_vars(self, vars: Dict = None) -> Dict:
+        """
+        NOTE: This is transparent to users.
+        When multiple instances of StepTrainer collaborate to do one training
+        step, each should use its own vars such as loss/model/step_control which are
+        local vars in Trainer.train; we define these vars by `train_step_vars`. They
+        are the vars needed by full_training_step for training control, as follows:
+        tr_loss, model, epoch, step, step_control.
+        Some vars such as `epoch` are meaningless; they are needed just because
+        full_training_step copies code from Trainer.train which is designed for a
+        complete training process.
+
+        Return `train_step_vars` (dict). If it does not exist, create it first. If `vars`
+        is not None, update `train_step_vars` with it.
+
+        TODO(guosheng): use namedtuple or dataclass to make it more readable.
+        """
+        if not hasattr(self, "train_step_vars"):
+            # should be called after model is wrapped since the model field should
+            # use model_wrapped.
+
+            assert self.model is not self.model_wrapped
+            self.train_step_vars = {
+                # meaningless vars can be passed from outside, a dummy value is enough
+                "epoch": 0,  # meaningless for step training
+                "step": 0,  # meaningless for step training
+                "steps_in_epoch": 100000,  # meaningless for step training
+                "step_control": 0,  # to control training process
+                "model": self.model_wrapped,
+                # "tr_loss": paddle.to_tensor(0.0),  # lazy create
+            }
+        if vars:
+            self.train_step_vars.update(vars)
+        return self.train_step_vars
+
+    @property
+    def loss_names(self):
+        if not hasattr(self, "_loss_names"):
+            self._loss_names = [var_name for var_name in self.get_train_step_vars() if var_name.endswith("_loss")]
+            assert len(self._loss_names) > 0
+        return self._loss_names
+
+    def full_training_step(self, **inputs) -> paddle.Tensor:
+        """
+        Accept any valid keyword arguments of model and loss as inputs; they
+        would be sent to the model and then to the loss. Mostly it is similar to
+        the output from the data collator.
+        Return the loss var. However, when using PipelineParallel, the loss returned
+        is 0 before reaching an accumulated step, and the loss returned at an
+        accumulated step is a mixed loss. We can use `get_step_loss` to get the actual loss.
+        """
+        # if model has multi losses which are combined into one mixed criterion,
+        # loss statistic var may change for different training steps according
+        # to inputs.
+        train_step_vars = self.get_train_step_vars()
+        loss_name = self.loss_identifier(inputs)
+        loss_var = train_step_vars.get(loss_name, None)
+        # trainer.train uses `tr_loss` as the loss var to accumulate loss.
+        # NOTE: `tr_loss` in trainer.train not only accumulates the mean loss for
+        # steps in one `gradient_accumulation_steps`, but also accumulates for
+        # one logging interval which may contain more than one accumulated step.
+        # However, in StepTrainer we only want to use `tr_loss` to accumulate the
+        # mean loss for steps in a `gradient_accumulation_steps` range. 
As for + # logging intervel loss accumulation is not take into account here and + # should be considered in outter. + if loss_var is None: # the first step of current loss type + loss_var = paddle.to_tensor(0.0) + train_step_vars[loss_name] = loss_var + elif self.is_accumulation_step: # begin a new accumulation step intervel + for name in self.loss_names: + train_step_vars[name] = paddle.to_tensor(0.0) + loss_var = train_step_vars[loss_name] + + train_step_vars["tr_loss"] = loss_var + + new_train_step_vars = super().full_training_step(inputs, **train_step_vars) + + # minimally update + train_step_vars = self.get_train_step_vars( + {"step_control": new_train_step_vars["step_control"], loss_name: new_train_step_vars["tr_loss"]} ) + if loss_name != "tr_loss": + train_step_vars.pop("tr_loss") - logits = outputs["logits"] if isinstance(outputs, dict) else outputs - if isinstance(outputs, dict): - logits = outputs["logits"] - elif isinstance(outputs, tuple): - logits = outputs[0] + self.mark_step_loss(loss_name) - log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) - actor_loss = self.actor_loss_fn( - log_probs[:, start:], - old_log_probs[:, start:], - reward_advantages, - sequence_mask[:, start:], + if self.use_ema and self.is_accumulation_step: + # TODO(guosheng): assume rollout next thus make ema weights on gpu, + # but may not, maybe need a way to specify it. + self.ema_update(beta=self.ema_beta, offload_ema=self.offload_ema, offload_model=not self.offload_ema) + + return train_step_vars[loss_name] + + def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[str, Union[paddle.Tensor, Any]]: + """ + trainer.criterion only support criterion(prediction, labels), so we need + to reorganize the inputs to extract label data into one argument. This is + only used in non-PipelineParallel model training since loss is included + in PipelineLayer. + """ + inputs = super()._prepare_input(inputs) + if self.criterion is None or getattr(self.criterion, "label_names", None) is None: + return inputs + # criterion created by create_loss has `label_names` and `label_default_values` + label_names = self.criterion.__class__.label_names + # some data fields are used both in model and loss + shared_fields = set(["input_ids", "attention_mask"]) + labels = [] + for name in label_names: + if name not in inputs: + label = self.criterion.__class__.label_default_values.get(name, None) + elif name in shared_fields: + label = inputs[name] + else: + label = inputs.pop(name) + labels.append(label) + # "labels" is the pre-defined label name in Trainer + inputs["labels"] = labels + # NOTE: TensorParallel model requires non-Tensor inputs to be lists and + # broadcast them, thus do not or optionally use these inputs. labels use + # in criterion not send to model can workaround this. + return inputs + + def mark_step_loss(self, loss_name): + """ + NOTE: This is transparent to users. + When using a mixed loss we often want to get the separated loss metrics, + thus we mark loss type of each training step to separate them. This is + not necessary since the loss would be returnd after each training step. + However when using PipelienParallel, the loss returned is 0 when not reach + accumulated step and the loss returned at accumulated step is a mixed loss. + To separate loss metrics in PipelienParallel: + 1. We hack PipelineParallel._forward_step to record actual loss for each + step in a list (only in training and not in evaluation currently). + 2. 
We mark the loss type only once for each step using `loss_step_indice` + (dict), then wen can check out the corresponding loss metrics from the + loss list. + We assume a static order of multi-losses and mark the loss indice only once. + """ + self.loss_step_indice = getattr(self, "loss_step_indice", {}) + if loss_name not in self.loss_step_indice: + self.loss_step_indice[loss_name] = len(self.loss_step_indice) + + @paddle.no_grad() + def get_step_loss(self, loss_prefix: str = "", loss_accumulator: Dict = {}) -> Dict[str, paddle.Tensor]: + """ + Return a dict mapping loss name to value of current training step. This + is mainly to get loss for metric logging, and it would not affect the + training. This is mostly helpful to PipelienParallel with a mixed loss + in which the loss returned is 0 when not reach accumulated step and the + loss returned at accumulated step is a mixed loss. + NOTE: 1. Only when reaching accumulated step the losses returned are + accurate, and each loss is a mean loss of steps among one accumulated + steps range. + """ + if not self.is_accumulation_step: + msg = "The loss returned may not be accurate when not reaching accumulated step." + logger.error(msg) + model = self.get_model(train=True) + loss_dict = loss_accumulator if loss_accumulator else {} + if isinstance(model, fleet.model.PipelineParallel) and len(self.loss_names) > 1: + # NOTE: PipelineParallel only returns a accumulated loss after + # accumulated steps, which is a mixed loss of ppo-loss and + # ptx-loss. We hack PipelineParallel._forward_step to record + # loss metrics and postprocess the recorded losses here. + # Maybe better to make the last_stage worker log to reduce + # comm and for simplicity. + with paddle.no_grad(): + if model.is_pipeline_last_stage(): + # loss is 0D tensor, use stack rather than concat + mix_loss = paddle.stack(model._step_losses) + model._step_losses = None + else: + # The tessor shape is not policy_model.accumulate_steps + # (args.accu_steps) but policy_trainer.args.accu_steps, + # since policy_model is created with global pp_config + # using global args.accu_steps which is only half of + # policy_trainer.args.accu_steps, and indeed trainer hack + # model.accumulate_steps in training_pipeline_step to use + # trainer.args.accu_steps. The dtype is fp32(to be check), + # thus no need to broadcast. + mix_loss = paddle.empty(shape=[self.args.gradient_accumulation_steps], dtype=paddle.float32) + paddle.distributed.broadcast(mix_loss, src=model.pp_group.ranks[-1], group=model.pp_group) + for loss_name in self.loss_names: + # We assume a static order of multi-losses and mark the loss + # indice only once. 
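+                    # For example, with loss_step_indice == {"actor_loss": 0, "ptx_loss": 1}
+                    # and gradient_accumulation_steps == 4, mix_loss holds
+                    # [actor, ptx, actor, ptx], so mix_loss[0::2].mean() and
+                    # mix_loss[1::2].mean() recover the per-loss means.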
+ value = mix_loss[self.loss_step_indice[loss_name] :: len(self.loss_names)].mean() + loss_name = loss_prefix + loss_name if loss_prefix else loss_name + loss_dict[loss_name] = loss_dict[loss_name].add_(value) if loss_name in loss_dict else value + return loss_dict + elif isinstance(model, fleet.model.PipelineParallel): + model._step_losses = None + + for loss_name in self.loss_names: + value = self.get_train_step_vars()[loss_name] + loss_name = loss_prefix + loss_name if loss_prefix else loss_name + loss_dict[loss_name] = loss_dict[loss_name].add_(value) if loss_name in loss_dict else value + return loss_dict + + @property + def is_accumulation_step(self): + """Indicate whether accumulation steps' training is done.""" + return self.get_train_step_vars()["step_control"] == 0 + + def get_sharding_master_weight_structured_names(self, model, optimizer): + rank_param_names = [p.name for p in optimizer._rank2params[optimizer._sharding_rank]] + structured_names = [] + # for pipeline model, use `model.state_dict()` would auto map param name + # for name, p in model.named_parameters(): + for name, p in model.state_dict().items(): + if p.name in rank_param_names: + structured_names.append(name) + return structured_names + + def get_master_weight_state_dict(self, model, optimizer): + if self.amp_dtype in ["float16", "bfloat16"] and hasattr(optimizer, "_master_weights"): + master_weights = dict(optimizer._master_weights) + result = {} + # for pipeline model, use `model.state_dict()` would auto map param name + # for name, p in model.named_parameters(): + for name, p in model.state_dict().items(): + if p.name in master_weights: + result[name] = master_weights[p.name] + return result + else: + return model.state_dict() + + def ema_init(self, offload_ema=True, offload_model=False, shard_ema=True): + """should be called after model and optimizer are created and wrapped""" + self.ema_state_dict = {} + self.bak_state_dict = {} + hcg = fleet.get_hybrid_communicate_group() + sharding_size = hcg.get_sharding_parallel_world_size() + # NOTE: use optimizer.master_weight instead of model.state_dict to set + # ema_state_dict would make ema coupled with master_weight reshard. + structured_names = ( + self.get_sharding_master_weight_structured_names(self.model, self.optimizer) + if sharding_size > 1 and shard_ema + else None ) + # for pipeline model, use `model.state_dict()` would auto map param name + # for name, p in self.model.named_parameters(): + for name, p in self.model.state_dict().items(): + if structured_names is None or name in structured_names: + ema_p = p.detach().cast(dtype=paddle.float32) + if offload_ema: + ema_p = ema_p.pin_memory() + self.ema_state_dict[name] = ema_p + if offload_model: + cpu_p = p.pin_memory() + cpu_p._share_buffer_to(p) + self.bak_state_dict[name] = p + if getattr(self.model, "tie_word_embeddings", False): + raise NotImplementedError - return actor_loss + @paddle.no_grad() + def ema_update(self, beta=0.992, offload_ema=True, offload_model=False): + """ + This would be called automatically in `full_training_step` if `use_ema` + is True to update ema state when ending an accumulated step intervel. 
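+        The update rule is `ema = beta * ema + (1 - beta) * param`, computed in
+        float32; when `offload_ema` is True the updated EMA weights are kept in
+        pinned host memory.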
+ """ + model_keys = list(self.ema_state_dict.keys()) + hcg = fleet.get_hybrid_communicate_group() + sharding_size = hcg.get_sharding_parallel_world_size() + trainer_state_dict = ( + self.get_master_weight_state_dict(self.model, self.optimizer) + if sharding_size > 1 and self.shard_ema + else self.model.state_dict() + ) + for key in model_keys: + if getattr(self.model, "tie_word_embeddings", False) and "lm_head" in key: + raise NotImplementedError + trainer_data = trainer_state_dict[key].cuda() + if trainer_data.dtype != paddle.float32: + # use model state dict instead of master weights + trainer_data = trainer_data.cast(dtype=paddle.float32) + ema_data = self.ema_state_dict[key].cuda() + # update ema & offload ema + ema_result = (beta * ema_data) + (1.0 - beta) * trainer_data + self.ema_state_dict[key] = ema_result.pin_memory() if offload_ema else ema_result + if offload_model: + cpu_p = trainer_data.pin_memory() + cpu_p._share_buffer_to(trainer_data) + if getattr(self.model, "tie_word_embeddings", False): + raise NotImplementedError + + def ema_apply(self): + """ + If use sharding and `shard_ema` is true, `ema_state_dict` only includes + sharded weights, thus we need the completed ema state to apply it to model + and ema would be coupled with reshard, then we need to reshard here. + """ + # TODO(guosheng): `bak_state_dict` is indeed trainer.model, allow to use + # a new model instead of trainer.model as target model. + # NOTE: if `shard_ema` is True, `ema_state_dict` is just a subset (sharded + # part) of model state_dict, and ema would coupled with reshard. + for k, v in self.bak_state_dict.items(): + # TODO(guosheng): reshard here + value = self.ema_state_dict[k].cuda().cast(dtype=v.dtype) + value._share_buffer_to(v) + + def ema_restore(self): + for k, v in self.bak_state_dict.items(): + value = v.cuda() + value._share_buffer_to(v) + if self.offload_ema: # ema weights always in pin_memory in fact + ema_v = self.ema_state_dict[k] + ema_value = ema_v.pin_memory() + ema_value._share_buffer_to(ema_v) + + +class ema(paddle.no_grad.__mro__[1]): + def __init__(self, trainer: StepTrainer): + self.trainer = trainer + + def __enter__(self): + trainer = self.trainer + if trainer.use_ema and not hasattr(trainer, "ema_state_dict"): + # call ema_init here since it should be called after model and + # optimizer are created and wrapped + trainer.ema_init( + offload_ema=trainer.offload_ema, offload_model=not trainer.offload_ema, shard_ema=trainer.shard_ema + ) + if self.trainer.use_ema: + self.trainer.ema_apply() + + def __exit__(self, *args): + if self.trainer.use_ema: + self.trainer.ema_restore() + + +class enable(paddle.no_grad.__mro__[1]): + """offload""" - def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs): + def __init__(self, *args): + self.objs = args + + def __enter__(self): + for obj in self.objs: + if hasattr(obj, "enable"): + obj.enable() + else: + reload_tensor_to_gpu(obj.state_dict()) + # offload_tensor_to_cpu/reload_tensor_to_gpu use non-blocking copy + # maybe overlap with compute later + if len(self.objs) > 0: + paddle.device.synchronize() + + def __exit__(self, *args): + for obj in self.objs: + if hasattr(obj, "disable"): + obj.disable() + else: + offload_tensor_to_cpu(obj.state_dict()) + # offload_tensor_to_cpu/reload_tensor_to_gpu use non-blocking copy + # maybe overlap with compute later + if len(self.objs) > 0: + paddle.device.synchronize() + + +class PolicyTrainer(StepTrainer): + loss_cls = RLHFPPOMixedLoss + + def loss_identifier(self, inputs: 
Dict) -> str: labels = inputs.get("labels", None) if labels is not None: # use ptx loss_name = "ptx_loss" else: loss_name = "actor_loss" - kwargs["model"] = kwargs.pop("policy_model") - kwargs["step_control"] = kwargs.pop("policy_step_control") - kwargs["tr_loss"] = kwargs.pop(loss_name) - kwargs = super().full_training_step(inputs, **kwargs) - kwargs["policy_model"] = kwargs.pop("model") - kwargs["policy_step_control"] = kwargs.pop("step_control") - kwargs[loss_name] = kwargs.pop("tr_loss") - return kwargs - - -class ValueTrainer(Trainer): - def __init__( - self, - model: Union[PretrainedModel, nn.Layer] = None, - criterion: nn.Layer = None, - args: TrainingArguments = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Union[Dataset, Dict[str, Dataset]] = None, - tokenizer: Optional[PretrainedTokenizer] = None, - compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), - preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, - ): - - super().__init__( - model, - criterion, - args, - data_collator, - train_dataset, - eval_dataset, - tokenizer, - compute_metrics, - callbacks, - optimizers, - preprocess_logits_for_metrics, - ) + return loss_name + + +class ValueTrainer(StepTrainer): + loss_cls = RLHFValueLoss + # define loss name for logging + loss_identifier = lambda self, inputs: "reward_critic_loss" + + +class PPOMetric: + def set_metric_meta(self, use_ptx=True): + self.metric_names = [ + "train/" + name + for name in [ + "actor_loss", + "ptx_loss", + "reward_critic_loss", + "reward", + "kl_divergence", + "mean_generated_length", + "max_generated_length", + ] + ] - def critic_loss_fn( - self, - values: paddle.Tensor, - old_values: paddle.Tensor, - returns: paddle.Tensor, - mask: paddle.Tensor, - ) -> paddle.Tensor: - """Compute critic loss.""" - # TODO(guosheng): use paddle.clip when its min/max can support more than - # 0D Tensor - values_clipped = paddle.minimum( - paddle.maximum(values, old_values - self.clip_range_value), old_values + self.clip_range_value - ) - vf_loss1 = paddle.square(values - returns) - vf_loss2 = paddle.square(values_clipped - returns) - return 0.5 * paddle.sum(paddle.maximum(vf_loss1, vf_loss2) * mask) / mask.sum() + self.metric_ops = ["mean", "mean", "mean", "mean", "mean", "mean", "max"] + if not use_ptx: + self.metric_names.pop(1) + self.metric_ops.pop(1) + + def __init__(self, freq, use_stack=True, use_ptx=True): + self.set_metric_meta(use_ptx=use_ptx) + self.freq = freq + self.counter = 0 + self.use_stack = use_stack + if use_stack: + self.metrics = paddle.zeros([freq, len(self.metric_names)], dtype=paddle.float32) + else: + self.metrics = [None] * len(self.metric_names) + for i in range(len(self.metrics)): + self.metrics[i] = paddle.zeros([freq], dtype=paddle.float32) - def compute_loss(self, model, inputs, return_outputs=False): + @paddle.no_grad() + def update(self, metrics: Dict[str, paddle.Tensor]) -> Union[None, Dict[str, float]]: """ - How the loss is computed by Trainer. By default, all models return the loss in the first element. - Subclass and override for custom behavior. + If has updated for`freq` times then return metrics (results reduced from + all worker) and reset metric states, otherwise return `None`. 
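+        Each metric is reduced over the `freq` updates and across all workers
+        using the op in `metric_ops` (mean for all fields except
+        max_generated_length, which uses max).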
""" - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - old_reward_values = inputs["old_reward_values"] - reward_returns = inputs["reward_returns"] - sequence_mask = inputs["sequence_mask"] - start = inputs["start"] - # NOTE: TensorParallel model requires non-Tensor inputs to be lists, thus - # do not use these inputs currently. - # use_cache = inputs["use_cache"] - # return_dict = inputs["return_dict"] - outputs = model( - input_ids=input_ids, - attention_mask=attention_mask, # use_cache=use_cache, return_dict=return_dict - ) + for name in self.metric_names: + # PipelineParallel broadcast loss with shape [1] + if len(metrics[name].shape) != 0: + metrics[name] = metrics[name].squeeze() + if metrics[name].dtype != paddle.float32: + metrics[name] = metrics[name].cast(paddle.float32) + if self.use_stack: + self.metrics[self.counter] = paddle.stack([metrics[name] for name in self.metric_names]) + else: + for i, name in enumerate(self.metric_names): + self.metrics[i][self.counter] = metrics[name] + if self.counter + 1 == self.freq: + from paddlenlp.trainer.utils import distributed_concat + + metrics = distributed_concat(self.metrics) + out_metrics = {} + if self.use_stack: + mean_metric = metrics.mean(0) + max_metric = metrics.max(0) + for i, (name, op) in enumerate(zip(self.metric_names, self.metric_ops)): + if op == "max": + out_metrics[name] = max_metric[i].item() if self.use_stack else metrics[i].max().item() + else: + out_metrics[name] = mean_metric[i].item() if self.use_stack else metrics[i].mean().item() - # We don't use .loss here since the model may return tuples instead of ModelOutput. - reward_values = outputs["scores"] if isinstance(outputs, dict) else outputs - if isinstance(outputs, dict): - reward_values = outputs["scores"] - elif isinstance(outputs, tuple): - reward_values = outputs[0] - - reward_values = reward_values.squeeze(axis=-1)[:, :-1] - reward_critic_loss = self.critic_loss_fn( - reward_values[:, start:], - old_reward_values[:, start:], - reward_returns, - sequence_mask[:, start:], - ) + # reset + self.counter = 0 + if self.use_stack: + self.metrics.fill_(0.0) + else: + for i, name in enumerate(self.metric_names): + self.metrics[i].fill_(0.0) + return out_metrics - return reward_critic_loss - - def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs): - kwargs["model"] = kwargs.pop("value_model") - kwargs["step_control"] = kwargs.pop("value_step_control") - kwargs["tr_loss"] = kwargs.pop("reward_critic_loss") - kwargs = super().full_training_step(inputs, **kwargs) - kwargs["value_model"] = kwargs.pop("model") - kwargs["value_step_control"] = kwargs.pop("step_control") - kwargs["reward_critic_loss"] = kwargs.pop("tr_loss") - return kwargs - - -@contextmanager -def guard_set_args(args, arg_name_values): - for k, v in arg_name_values.items(): - old_value = getattr(args, k, None) - setattr(args, k, v) - arg_name_values[k] = old_value - yield - for k, v in arg_name_values.items(): - old_value = getattr(args, k) - setattr(args, k, v) - arg_name_values[k] = old_value - - -class MuteDefaultFlowCallback(TrainerCallback): - def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - control.should_save = False - control.should_evaluate = False - control.should_log = False - return control - - -def is_same_tokenizer( - tokenizer: PretrainedTokenizer, - other_tokenizer: PretrainedTokenizer, -) -> bool: - """Check if two tokenizers are the same.""" - return tokenizer is other_tokenizer 
or ( - tokenizer.__class__ == other_tokenizer.__class__ and tokenizer.get_vocab() == other_tokenizer.get_vocab() - ) + +def data_dispatch(fun): + def _impl(self, data): + gp = getattr(self.policy_trainer, "_data_trans_group", None) + data = data_group_split(data, group=gp) + data = fun(self, data) + data = data_group_merge(data, group=gp) + return data + + return _impl class PPOTrainer(Trainer): @@ -729,7 +669,14 @@ def __init__( optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, ): - with guard_set_args(args, {"recompute": False, "fp16_opt_level": "O1"}): + with guard_set_args( + args, + { + "recompute": False, + "fp16_opt_level": "O1", + "pipeline_parallel_degree": 1, # workaround for pipeline parallel model check + }, + ): # just used to create trival attrs might be used in the training # process of trainer, while changing some args to avoid model usage # in __init__ such as recompute and AMP-O2 @@ -751,7 +698,11 @@ def __init__( self.ptx_dataset = ptx_dataset self.eval_dataset = eval_dataset - (policy_model, reference_model, reward_model, value_model) = model + (policy_model, reference_model, reward_model, value_model, policy_model_eval, value_model_eval) = model + self._model_config = policy_model.config # use this to change flash attention dynamicly + self._policy_model_eval = policy_model_eval + self._value_model_eval = value_model_eval + # policy_tokenizer and value_tokenizer should be same (policy_tokenizer, reference_tokenizer, reward_tokenizer, value_tokenizer) = tokenizer @@ -795,14 +746,34 @@ def __init__( optimizers, preprocess_logits_for_metrics, ) + self.policy_trainer.set_eval_model(policy_model_eval) + self.value_trainer.set_eval_model(value_model_eval) + # disable inner trainers' callback/state/control + self.policy_trainer.add_callback(MuteDefaultFlowCallback) + self.value_trainer.add_callback(MuteDefaultFlowCallback) # use trainer for reference_model/reward_model to enable sharding stage-3 - # maybe we should allow models to use different dist strategies later - if ShardingOption.FULL_SHARD in args.sharding: - self.reference_trainer = Trainer( + # and PipelineParallel. maybe we should allow models to use different dist + # strategies later + + from paddle.distributed.fleet.meta_parallel import PipelineLayer + + # allow reference_model/reward_model to use different dist strategy + with guard_set_args( + args, + { + "recompute": False, + # "fp16_opt_level": "O1", + "pipeline_parallel_degree": args.pipeline_parallel_degree + if isinstance(reference_model, PipelineLayer) + else 1, # workaround for pipeline parallel model check + }, + ): + + self.reference_trainer = StepTrainer( reference_model, criterion, - args, + copy.deepcopy(args), data_collator, train_dataset, eval_dataset, @@ -812,10 +783,10 @@ def __init__( optimizers, preprocess_logits_for_metrics, ) - self.reward_trainer = Trainer( + self.reward_trainer = StepTrainer( reward_model, criterion, - args, + copy.deepcopy(args), data_collator, train_dataset, eval_dataset, @@ -827,11 +798,12 @@ def __init__( ) # TODO(guosheng): sharding stage3 should create master weight optionally # instead of creation and clear. 
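+        # NOTE: init_train_model_opt below is called with a dummy max_steps only
+        # to wrap the frozen reference/reward models for sharding stage3 or
+        # pipeline parallel; their master weights are created and then cleared
+        # immediately (clear_master_weight=True) since these models are never
+        # optimized.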
- self.reference_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps - self.reward_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps - else: - self._reference_model = reference_model - self._reward_model = reward_model + from paddlenlp.trainer.trainer_utils import ShardingOption + + if args.pipeline_parallel_degree > 1 or ShardingOption.FULL_SHARD in args.sharding: + self.reference_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps + self.reward_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps + self.reference_model.eval() self.reward_model.eval() @@ -841,13 +813,14 @@ def __init__( self.reward_tokenizer = self.tokenizer self.generation_config = GenerationConfig( - max_length=self.args.max_length, + max_new_tokens=self.args.max_length, num_return_sequences=self.args.num_return_sequences, temperature=self.args.temperature, top_p=self.args.top_p, - # top_k=self.args.top_k, + top_k=0, # to disable top_k sampling, default is 50 repetition_penalty=self.args.repetition_penalty, do_sample=True, + # allow generation output to contain input trunc_input=False, bos_token_id=self.tokenizer.bos_token_id, eos_token_id=self.tokenizer.eos_token_id, @@ -855,10 +828,8 @@ def __init__( ) # Those value can be changed self.kl_coeff = self.args.kl_coeff - self.policy_trainer.clip_range_ratio = self.clip_range_ratio = self.args.clip_range_ratio self.clip_range_score = self.args.clip_range_score - self.value_trainer.clip_range_value = self.clip_range_value = self.args.clip_range_value - self.policy_trainer.ptx_coeff = self.ptx_coeff = self.args.ptx_coeff + self.ptx_coeff = self.args.ptx_coeff self.gamma = 1.0 self.gae_lambda = 0.95 @@ -868,63 +839,26 @@ def __init__( "DummyPPOModel", (object,), {"eval": lambda _: self.set_eval(), "train": lambda _: self.set_train()} ) self.model = self.model_wrapped = self.DummyPPOModel() - # self.optimizer = self.policy_trainer.optimizer - # self.scaler = self.reference_trainer.scaler = self.reward_trainer.scaler = None @property def reference_model(self): - # use model without Trainer - model = getattr(self, "_reference_model", None) - if model is not None: - return model - # use model with Trainer - if self.reference_trainer.args.pipeline_parallel_degree > 1: - # Only accept wrapped model for pipeline_parallel mode - model = self.reference_trainer.model_wrapped - else: - model = self.reference_trainer.model - return model + return self.reference_trainer.get_model(train=False) @property def reward_model(self): - # use model without Trainer - model = getattr(self, "_reward_model", None) - if model is not None: - return model - # use model with Trainer - if self.reward_trainer.args.pipeline_parallel_degree > 1: - # Only accept wrapped model for pipeline_parallel mode - model = self.reward_trainer.model_wrapped - else: - model = self.reward_trainer.model - return model + return self.reward_trainer.get_model(train=False) @property def actor_model(self): - if self.training: - return self.policy_trainer.model_wrapped - if self.policy_trainer.args.pipeline_parallel_degree > 1: - # Only accept wrapped model for pipeline_parallel mode - model = self.policy_trainer.model_wrapped - else: - model = self.policy_trainer.model - return model + return self.policy_trainer.get_model(train=self.training) @property def reward_critic_model(self): - if self.training: - return self.value_trainer.model_wrapped - if self.value_trainer.args.pipeline_parallel_degree > 1: - 
# Only accept wrapped model for pipeline_parallel mode - model = self.value_trainer.model_wrapped - else: - model = self.value_trainer.model - return model + return self.value_trainer.get_model(train=self.training) def set_train(self, mode: bool = True) -> None: """Set training mode for all models.""" if mode: - # self.is_in_train = True self.training = True self.actor_model.train() self.reward_critic_model.train() @@ -944,25 +878,19 @@ def prediction_step( prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[paddle.Tensor]]: - if self.args.pipeline_parallel_degree > 1: - # hack for pipeline mode - inputs = self._prepare_inputs(inputs) - return self.prediction_pipeline_step(model, inputs, prediction_loss_only, ignore_keys) - else: - inputs = self._prepare_inputs(inputs) + inputs = self._prepare_inputs(inputs) with paddle.no_grad(): with self.autocast_smart_context_manager(): seq = self.actor_model.generate( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"] + if "position_ids" in inputs + else make_position_ids(inputs["attention_mask"]), generation_config=self.generation_config, synced_gpus=ShardingOption.FULL_SHARD in self.policy_trainer.args.sharding, )[0] - attention_mask = paddle.logical_and( - seq != self.tokenizer.pad_token_id, - seq != self.tokenizer.unk_token_id, - ) if self.reward_tokenizer is not self.tokenizer: reward_tokenize_output = batch_retokenize( input_ids=seq, @@ -972,14 +900,26 @@ def prediction_step( device=self.args.device, ) reward_input_ids = reward_tokenize_output["input_ids"] - reward_attention_mask = reward_tokenize_output["attention_mask"] else: reward_input_ids = seq - reward_attention_mask = attention_mask + reward_attention_mask = make_attention_mask( + seq, + pad_id=self.reward_tokenizer.pad_token_id, + unk_id=self.reward_tokenizer.unk_token_id, + causal_mask=False, + ) + reward_position_ids = make_position_ids(reward_attention_mask) + # unify PP with others since PP always return tuple reward_score = self.reward_model( - reward_input_ids, attention_mask=reward_attention_mask, return_dict=True - ).end_scores.squeeze(axis=-1) + reward_input_ids, + attention_mask=reward_attention_mask, + position_ids=reward_position_ids, + # return_dict=True, + )[ + 1 + ] # .end_scores + reward_score = reward_score.squeeze(axis=-1).cast(paddle.float32) # keep the first batch of eval output sequence to print and check prompt = self.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) @@ -992,7 +932,7 @@ def prediction_step( # generateds.extend(generated) self._eval_seq = (prompt, generated, reward_score.tolist()) - return reward_score.cast(paddle.float32).mean(), None, None + return reward_score.mean(), None, None def evaluation_loop( self, @@ -1009,9 +949,20 @@ def evaluation_loop( ) self._eval_out_file = open(eval_out_file, "w") - output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix, max_eval_iters - ) + # TODO(guosheng): use _inner_eval_model (if trainer has one) instead of + # original trainer model to eval, especially when using sharded EMA + # NOTE: use here rather than in prediction_step since actor_model would + # be set to eval out of prediction_step + with guard_set_args( + self.policy_trainer, # disable _inner_eval_model + { + "_eval_model": None, # otherwise would use cached _eval_model + "_inner_eval_model": None, # otherwise would use 
_inner_eval_model to create _eval_model + }, + ): + output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix, max_eval_iters + ) output.metrics[f"{metric_key_prefix}/reward"] = output.metrics.pop(f"{metric_key_prefix}_loss") columns = ["Prompt", "Generated", "Reward"] @@ -1041,23 +992,6 @@ def _save_checkpoint(self, model, metrics=None): with guard_set_args(self.value_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "value")}): self.value_trainer._save_checkpoint(model, metrics) - # def _load_from_checkpoint(self, resume_from_checkpoint=None): - # with guard_set_args(self.policy_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "policy")}): - # self.policy_trainer._load_from_checkpoint(resume_from_checkpoint) - # with guard_set_args(self.value_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "value")}): - # self.value_trainer._load_from_checkpoint(resume_from_checkpoint) - - # def _load_optimizer_and_scheduler(self, checkpoint): - # # NOTE: `Trainer._load_optimizer_and_scheduler` would not seek the latest - # # state as in `_load_from_checkpoint``, and it just use `resume_from_checkpoint` - # # as value of `checkpoint` to load. - # self.policy_trainer._load_optimizer_and_scheduler( - # checkpoint if checkpoint is None else os.path.join(checkpoint, "policy") - # ) - # self.value_trainer._load_optimizer_and_scheduler( - # checkpoint if checkpoint is None else os.path.join(checkpoint, "value") - # ) - def init_train_model_opt( self: Trainer, max_steps: int, resume_from_checkpoint: bool = False, clear_master_weight: bool = False ) -> PretrainedModel: @@ -1080,23 +1014,24 @@ def init_train_model_opt( return policy_model, value_model def get_epoch_iterator(self): - # TODO(guosheng): support iter dataset - num_prompt_only_batches = len(self.prompt_only_dataloader) - num_ptx_batches = len(self.ptx_dataloader) - num_ptx_replicas = (num_prompt_only_batches + num_ptx_batches - 1) // num_ptx_batches - def gen_epoch_data(): for prompt_only_batch, ptx_batch in zip( self.prompt_only_dataloader, - itertools.chain.from_iterable([self.ptx_dataloader] * num_ptx_replicas), + itertools.cycle(self.ptx_dataloader), ): # generate batches self.set_eval() - rl_batches = self.split_rl_micro_batches(prompt_only_batch) + + with ema(self.policy_trainer), ema(self.value_trainer): + rl_batches = self.split_rl_micro_batches(prompt_only_batch) + + self.timers and self.timers("ptx-batch").start() if self.use_ptx: ptx_batches = self.split_ptx_micro_batches(ptx_batch) else: ptx_batches = [None for _ in range(len(rl_batches))] + self.timers and self.timers("ptx-batch").stop() + paddle.device.cuda.empty_cache() self.set_train() @@ -1108,31 +1043,48 @@ class EpochIterator: def __iter__(self): return gen_epoch_data() + def __len__(self): + return len(self.prompt_only_dataloader) * ( + self.args.update_iters + * self.args.per_device_prompt_batch_size + * self.args.num_return_sequences + // self.args.per_device_train_batch_size + ) + return EpochIterator() def init_train_num(self: Trainer, train_dataloader: DataLoader): args = self.args total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size - - len_dataloader = len(train_dataloader) - num_train_sub_steps = ( - len_dataloader - * self.args.update_iters - * self.args.per_device_prompt_batch_size - * self.args.num_return_sequences - // self.args.per_device_train_batch_size - ) - num_update_steps_per_epoch = num_train_sub_steps // 
args.gradient_accumulation_steps - if args.max_steps > 0: - max_steps = args.max_steps - num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( - args.max_steps % num_update_steps_per_epoch > 0 + len_dataloader = None + if not self._is_iterable_dataset(self.train_dataset): + len_dataloader = len(train_dataloader) + num_train_sub_steps = ( + len_dataloader + * self.args.update_iters + * self.args.per_device_prompt_batch_size + * self.args.num_return_sequences + // self.args.per_device_train_batch_size ) + num_update_steps_per_epoch = num_train_sub_steps // args.gradient_accumulation_steps + num_examples = len(self.train_dataset) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + else: + max_steps = int(num_update_steps_per_epoch * args.num_train_epochs) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = total_train_batch_size * max_steps else: - max_steps = int(num_update_steps_per_epoch * args.num_train_epochs) - num_train_epochs = math.ceil(args.num_train_epochs) - num_examples = num_train_samples = total_train_batch_size * max_steps + assert args.max_steps > 0 + max_steps = args.max_steps + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = args.max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size return ( total_train_batch_size, @@ -1144,6 +1096,18 @@ def init_train_num(self: Trainer, train_dataloader: DataLoader): num_train_samples, ) + def is_step_end(self): + # reach accumulation_steps, value trainer has the same step_control and + # gradient_accumulation_steps as PPO trainer. + # if (step_control + 1) % args.gradient_accumulation_steps == 0 + return self.value_trainer.is_accumulation_step + + def get_step_loss(self, loss_prefix: str = "") -> Dict: + rl_loss = self.policy_trainer.get_step_loss(loss_prefix) + value_loss = self.value_trainer.get_step_loss(loss_prefix) + rl_loss.update(value_loss) + return rl_loss + def train( self, resume_from_checkpoint: Optional[Union[str, bool]] = None, @@ -1166,15 +1130,16 @@ def train( with guard_set_args( args, { - "per_device_train_batch_size": self.args.per_device_prompt_batch_size - * self.args.num_return_sequences + "per_device_train_batch_size": 1 + if getattr(self.ptx_dataset, "is_intokens", False) + else self.args.per_device_prompt_batch_size * self.args.num_return_sequences }, ), guard_set_args( - self, {"train_dataset": self.ptx_dataset, "data_collator": self.ptx_dataset.get_collator(shift=True)} + self, {"train_dataset": self.ptx_dataset, "data_collator": self.ptx_dataset.get_collator()} ): self.ptx_dataloader = self.get_train_dataloader() else: - self.ptx_dataloader = DataLoader(DummyDataset(len(self.prompt_only_dataloader))) + self.ptx_dataloader = range(100) ( total_train_batch_size, len_dataloader, @@ -1186,16 +1151,8 @@ def train( ) = self.init_train_num(train_dataloader) # ##### model and optimizer related setting ##### - # policy_trainer/value_trainer only init train with init_train_model_opt, - # maybe more training setting used in full_training_step should be set here, - # such as trainer.control and trainer.state - # policy_model = self.policy_trainer.init_train_model_opt(max_steps, resume_from_checkpoint) - # value_model = self.value_trainer.init_train_model_opt(max_steps, resume_from_checkpoint) policy_model, value_model = self.init_train_model_opt(max_steps, 
resume_from_checkpoint) paddle.device.cuda.empty_cache() - # disable inner trainers' callback/state/control - self.policy_trainer.add_callback(MuteDefaultFlowCallback) - self.value_trainer.add_callback(MuteDefaultFlowCallback) # ##### traing statistic logging ##### # Number of trainable parameters only account for policy_model @@ -1224,39 +1181,11 @@ def train( self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - actor_loss = paddle.to_tensor(0.0) - reward_critic_loss = paddle.to_tensor(0.0) - ptx_loss = paddle.to_tensor(0.0) - # used when logging and last step - self._total_actor_loss_scalar = 0.0 - self._total_reward_critic_loss_scalar = 0.0 - self._total_ptx_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step - - # train_step_kwargs is used to provide arguments more than model inputs - # for full_training_step which is copied from Trainer.train and needs - # these arguments to control training process. - train_step_kwargs = { - "ignore_keys_for_eval": None, # no need - # TODO(guosheng): commented args mean to resume data, not support yet - # "resume_from_checkpoint": resume_from_checkpoint, - # "train_dataloader": train_dataloader, - # "epochs_trained": epochs_trained, - # "steps_trained_in_current_epoch": steps_trained_in_current_epoch, - # "steps_trained_progress_bar": steps_trained_progress_bar, - "steps_in_epoch": steps_in_epoch, # to control training process - # the following args are corresponding to tr_loss and model used in - # Trainer.train, and they would be used as tr_loss and model in - # PolicyTranier and ValueTrainer. - "actor_loss": actor_loss, - "reward_critic_loss": reward_critic_loss, - "ptx_loss": ptx_loss, - "policy_model": policy_model, - "value_model": value_model, - } + metric = PPOMetric(freq=self.args.logging_steps, use_ptx=self.use_ptx) start_time = time.time() - self._globalstep_last_start_time = start_time # time.time() + self._globalstep_last_start_time = start_time # self.timers and self.timers("read-data").start() for epoch in range(epochs_trained, num_train_epochs): @@ -1265,36 +1194,52 @@ def train( ): train_dataloader.batch_sampler.set_epoch(epoch) - step_control = 0 # used in loop control, reset to 0 after every step - train_step_kwargs.update({"policy_step_control": step_control, "value_step_control": step_control}) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) for step, inputs in enumerate(epoch_iterator): # self.timers and self.timers("read-data").stop() - os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) - self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) - # epoch, step and steps_in_epoch only mostly used in train_step by - # `self.state.epoch = epoch + (step + 1) / steps_in_epoch` if not - # resume data - train_step_kwargs.update({"epoch": epoch, "step": step}) + # os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) + # self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) rl_batch, ptx_batch = inputs # TODO(guosheng): make rl_step/ptx_step run with autocast_smart_context_manager - rl_info, train_step_kwargs = self.rl_step(rl_batch, **train_step_kwargs) - paddle.device.cuda.empty_cache() - if self.use_ptx: - ptx_info, train_step_kwargs = self.ptx_step(ptx_batch, **train_step_kwargs) - rl_info.update(ptx_info) + logger.info("Doing rl step...") + self.timers and self.timers("rl_step").start() + with self.enable(self.actor_model, self.policy_trainer.optimizer): + # 
with self.enable(self.value_trainer.optimizer): + with self.enable(): # put value optimizer guard in rl_step + rl_info = self.rl_step(rl_batch) paddle.device.cuda.empty_cache() + self.timers and self.timers("rl_step").stop() + + if self.use_ptx: + logger.info("Doing ptx step...") + self.timers and self.timers("ptx_step").start() + with guard_set_args( + self._model_config, + { + # "set_attn_func": True, + # "use_flash_attention": True + }, + ): + ptx_info = self.ptx_step(ptx_batch) + rl_info.update(ptx_info) + self.timers and self.timers("ptx_step").stop() + paddle.device.cuda.empty_cache() - self.state.global_step = self.value_trainer.state.global_step - self.state.epoch = self.value_trainer.state.epoch - if train_step_kwargs["value_step_control"] == 0: + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + if self.is_step_end(): + rl_info.update(self.get_step_loss(loss_prefix="train/")) + rl_info = metric.update(rl_info) # on_step_end self.control = self.callback_handler.on_step_end(args, self.state, self.control) else: # on_sub_step_end self.control = self.callback_handler.on_substep_end(args, self.state, self.control) self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=inputs) + self._print_timer() + if self.control.should_epoch_stop or self.control.should_training_stop: + break if step < 0: logger.warning( @@ -1307,6 +1252,7 @@ def train( self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) # argument model is not used in _maybe_log_save_evaluate, thus use None self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=inputs) + self._print_timer() if self.control.should_training_stop: break @@ -1316,29 +1262,16 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, if self.control.should_log: logs: Dict[str, float] = {} - - for k, v in tr_loss.items(): - if isinstance(v, paddle.Tensor) and "lr" not in k and "max_generated_length" not in k: - v_scalar = self._nested_gather(v).mean().item() - if "train/actor_loss" == k and "train/ptx_loss" in tr_loss: - # use_ptx would double the gradient_accumulation_steps - # which causes actor_loss and ptx_loss reduced by half - v_scalar = v_scalar * 2 - elif "train/ptx_loss" == k: - # similar to actor_loss and should double, additionally - # it should be divided by ptx_coeff for logging - v_scalar = v_scalar * 2 / self.ptx_coeff - logs[k] = round(v_scalar / (self.state.global_step - self._globalstep_last_logged), 8) - v.subtract_(v) - attr_name = "_total_" + k.split("/")[-1] + "_scalar" - attr_value = getattr(self, attr_name, 0) - setattr(self, attr_name, attr_value + v_scalar) - elif "max_generated_length" in k: - v_scalar = self._nested_gather(v).max().item() - logs[k] = v_scalar - else: - logs[k] = float("{0:.3e}".format(v)) + # use_ptx would double the gradient_accumulation_steps which causes + # actor_loss and ptx_loss reduced by half. Moreover, ptx_loss should + # be divided by ptx_coeff for logging. 
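+            # For example, if ptx_coeff == 16 and the raw lm loss is 2.0, the
+            # accumulated "train/ptx_loss" is about (16 * 2.0) / 2 = 16.0, and
+            # the correction below logs 16.0 * 2 / 16 = 2.0.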
+ if "train/ptx_loss" in tr_loss: + tr_loss["train/actor_loss"] = tr_loss["train/actor_loss"] * 2 + tr_loss["train/ptx_loss"] = tr_loss["train/ptx_loss"] * 2 / self.ptx_coeff + logs.update(tr_loss) logs["global_step"] = int(self.state.global_step) + logs["train/actor_lr"] = float("{0:.3e}".format(self.policy_trainer._get_learning_rate())) + logs["train/reward_critic_lr"] = float("{0:.3e}".format(self.value_trainer._get_learning_rate())) total_train_batch_size = ( self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size @@ -1377,11 +1310,15 @@ def add_kl_divergence_regularization( min=-self.clip_range_score, max=self.clip_range_score, ) - batch_size = log_probs.shape[0] - for i in range(batch_size): - end_index = sequence_mask[i].nonzero()[-1] - # rewards[i, end_index] += reward_clip[i] - rewards[i, end_index] = rewards[i, end_index] + reward_clip[i] + # TODO(guosheng): use scatter_add/put_along_axis + index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True) + rewards = paddle.put_along_axis(rewards, index, reward_clip.unsqueeze(axis=-1), axis=-1, reduce="add") + # batch_size = log_probs.shape[0] + # for i in range(batch_size): + # # print("="*20, sequence_mask[i]) + # end_index = sequence_mask[i].nonzero()[-1] + # # rewards[i, end_index] += reward_clip[i] + # rewards[i, end_index] = rewards[i, end_index] + reward_clip[i] return rewards @@ -1391,6 +1328,7 @@ def get_advantages_and_returns( rewards: paddle.Tensor, sequence_mask: paddle.Tensor, start: int, + use_tgt_len_return: bool = True, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute advantages and returns using Generalized Advantage Estimation (GAE).""" # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py @@ -1399,6 +1337,16 @@ def get_advantages_and_returns( values = values * sequence_mask rewards = rewards * sequence_mask length = rewards.shape[-1] + if use_tgt_len_return and start > 0: + # consistent with Beaver + # values length is src+tgt-1, start is src-1, return length is tgt + pass + elif use_tgt_len_return: + # values length is tgt, start is 0, return length is tgt + assert start == 0 + else: + # values length is src+tgt-1, start is src-1, return length is src+tgt-1 + pass for t in reversed(range(start, length)): # pylint: disable=invalid-name next_values = values[:, t + 1] if t < length - 1 else 0.0 delta = rewards[:, t] + self.gamma * next_values - values[:, t] @@ -1406,82 +1354,98 @@ def get_advantages_and_returns( advantages_reversed.append(last_gae_lambda) advantages = paddle.stack(advantages_reversed[::-1], axis=1) returns = advantages + values[:, start:] + if not use_tgt_len_return: + advantages = paddle.concat( + [paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), advantages], -1 + ) + returns = paddle.concat([paddle.zeros([returns.shape[0], start], dtype=returns.dtype), returns], -1) return advantages.detach(), returns - def rl_step(self, rl_batch: Dict[str, paddle.Tensor], **kwargs) -> Dict[str, Any]: - prompt = rl_batch["prompt"] - old_log_probs = rl_batch["log_probs"] - ref_log_probs = rl_batch["ref_log_probs"] - rewards = rl_batch["rewards"] - old_reward_values = rl_batch["reward_values"] - input_ids = rl_batch["input_ids"] - attention_mask = rl_batch["attention_mask"] - - start = prompt.shape[-1] - 1 - sequence_mask = attention_mask[:, 1:] + def rl_step(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: + # inputs shared by policy and value trainer + input_ids = 
rl_batch["input_ids"] # length: src+tgt + attention_mask = rl_batch["attention_mask"] # length: src+tgt + position_ids = rl_batch["position_ids"] # length: src+tgt + sequence_mask = rl_batch["sequence_mask"] # length: src+tgt(-1) + # inputs used by policy trainer + old_log_probs = rl_batch["log_probs"] # length: src+tgt(-1) + reward_advantages = rl_batch["reward_advantages"] # length: src+tgt(-1) + # inputs used by value trainer + old_reward_values = rl_batch["reward_values"] # length: src+tgt(-1) + reward_returns = rl_batch["reward_returns"] # length: src+tgt(-1) - with paddle.no_grad(): - # maybe these two can also be put into rollout - old_rewards = self.add_kl_divergence_regularization( - prompt, - old_log_probs, - ref_log_probs, - rewards, - sequence_mask, - ) - reward_advantages, reward_returns = self.get_advantages_and_returns( - old_reward_values, - old_rewards, - sequence_mask, - start, - ) - - policy_trainer_inputs = { + value_trainer_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, - "old_log_probs": old_log_probs, - "reward_advantages": reward_advantages, + "position_ids": position_ids, + "old_reward_values": old_reward_values, + "reward_returns": reward_returns, "sequence_mask": sequence_mask, - "start": start, - "use_cache": False, - "return_dict": True, } - kwargs = self.policy_trainer.full_training_step(policy_trainer_inputs, **kwargs) + with self.enable(self.reward_critic_model, self.value_trainer.optimizer): + reward_critic_loss = self.value_trainer.full_training_step(**value_trainer_inputs) - value_trainer_inputs = { + policy_trainer_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, - "old_reward_values": old_reward_values, - "reward_returns": reward_returns, + "position_ids": position_ids, + "old_log_probs": old_log_probs, + "reward_advantages": reward_advantages, "sequence_mask": sequence_mask, - "start": start, - "use_cache": False, - "return_dict": True, } - kwargs = self.value_trainer.full_training_step(value_trainer_inputs, **kwargs) + actor_loss = self.policy_trainer.full_training_step(**policy_trainer_inputs) + # metric with paddle.no_grad(): - kl_divergence = ((old_log_probs - ref_log_probs) * sequence_mask)[:, start:].sum(axis=-1).mean() - mean_generated_length = sequence_mask[:, start:].cast(paddle.float32).sum(axis=-1).mean() - max_generated_length = sequence_mask[:, start:].cast(paddle.float32).sum(axis=-1).max() - - rewards = rewards.mean() + rewards = rl_batch["rewards"] + rewards = rewards.mean() + ref_log_probs = rl_batch["ref_log_probs"] + kl_divergence = ((old_log_probs - ref_log_probs) * sequence_mask).sum(axis=-1).mean() + mean_generated_length = sequence_mask.cast(paddle.float32).sum(axis=-1).mean() + max_generated_length = sequence_mask.cast(paddle.float32).sum(axis=-1).max() return { - "train/actor_loss": kwargs["actor_loss"], - "train/reward_critic_loss": kwargs["reward_critic_loss"], + # when using PipelienParallel, the loss returned is 0 when not reach + # accumulated step and the loss returned at accumulated step is a + # mixed loss. 
+ "train/actor_loss": actor_loss, + "train/reward_critic_loss": reward_critic_loss, "train/reward": rewards, "train/kl_divergence": kl_divergence, "train/mean_generated_length": mean_generated_length, "train/max_generated_length": max_generated_length, - "train/actor_lr": self.policy_trainer._get_learning_rate(), - "train/reward_critic_lr": self.value_trainer._get_learning_rate(), - }, kwargs + } - def ptx_step(self, ptx_batch: Dict[str, paddle.Tensor], **kwargs) -> Dict[str, Any]: + def ptx_step(self, ptx_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: """Perform a single update step with PTX loss.""" - kwargs = self.policy_trainer.full_training_step(ptx_batch, **kwargs) - return {"train/ptx_loss": kwargs["ptx_loss"]}, kwargs + # sft inputs use right padding, position_ids is optional + # ptx_batch["position_ids"] = ptx_batch.get( + # "position_ids", make_position_ids(ptx_batch["attention_mask"])) + ptx_loss = self.policy_trainer.full_training_step(**ptx_batch) + return { + "train/ptx_loss": ptx_loss, + } + + def enable(self, *args): + # note: must keep the same model since actor_model, reward_model etc. + # are property + enable_map = { + # maybe use `model: (pattern, enable_method, disable_method)`` + self.actor_model: "train_model", + self.reward_critic_model: "train_model", + self.reference_model: "freeze_model", + self.reward_model: "freeze_model", + self.policy_trainer.optimizer: "optimizer", + self.value_trainer.optimizer: "optimizer", + } + # if use an extra eval model to do eval/generation, switch on actor_model + # and reward_critic_model; otherwise no need to switch + if getattr(self.policy_trainer, "_inner_eval_model", None) is not None: + enable_map.pop(self.actor_model) + if getattr(self.value_trainer, "_inner_eval_model", None) is not None: + enable_map.pop(self.reward_critic_model) + objs = [arg for arg in args if enable_map.get(arg, "") in self.args.offload_level] + return enable(*objs) def split_ptx_micro_batches( self, @@ -1500,49 +1464,184 @@ def split_ptx_micro_batches( micro_batches.append(micro_batch) return micro_batches + # @staticmethod + # def data_dispatch(fun): + # def _impl(self, data): + # gp = getattr(self.policy_trainer, "_data_trans_group", None) + # data = data_group_split(data, group=gp) + # data = fun(self, data) + # data = data_group_merge(data, group=gp) + # return data + + # return _impl + + @paddle.no_grad() + @data_dispatch # 3.10 static methods are now callable as regular functions. def split_rl_micro_batches( self, - prompt_only_batch: PromptOnlyBatch, - ) -> List[PromptOnlyBatch]: + prompt_only_batch: Dict, + ) -> List[Dict]: """Split a batch of RL samples into micro-batches.""" total_batch_size = prompt_only_batch["input_ids"].shape[0] micro_batch_size = self.args.per_device_train_batch_size micro_batches = [] - for i in range(0, total_batch_size, micro_batch_size): - micro_batch = {} - micro_batch = map_structure( - lambda tensor: tensor[i : i + micro_batch_size], - prompt_only_batch, - ) - micro_batches.extend(self.rollout(micro_batch)) + + # TODO(guosheng): clean get_epoch_iterator: + # 1. scope guard for offload, we would split post_rollout into multiple + # sub-methods to offload in-time + # 2. 
decorate split_rl_micro_batches to automatically split/merge data + with self.enable(self.actor_model, self.reference_model): + # generate for multi batches and then disable FuseMT model + with infer_guard(self.policy_trainer): + # dist.barrier() + # print("="*20, "begin generate") + for i in range(0, total_batch_size, micro_batch_size): + micro_batch = {} + micro_batch = map_structure( + lambda tensor: tensor[i : i + micro_batch_size], + prompt_only_batch, + ) + micro_batches.extend(self.generate(micro_batch)) + # dist.barrier() + # paddle.device.cuda.synchronize() + # get log_probs for multi batches and then disable actor/refer rmodel + for micro_batch in micro_batches: + # position_ids is necessary for non-right padding + # If using right padding source + left padding target, make padding positions + # in source be 0, since reward model use position_ids plus with padding size + # (number of 0s) in source to calculate end offsets. + micro_batch["position_ids"] = make_position_ids(micro_batch["attention_mask"]) + micro_batch.update(self.rollout_logprob(**micro_batch)) + # print("="*20, "micro_batch", micro_batch) + + # get reward/value for multi batches and then disable reward/value model + with self.enable(self.reward_critic_model, self.reward_model): + for micro_batch in micro_batches: + micro_batch.update(self.rollout_reward_value(**micro_batch)) + + # + micro_batches = [self.normalize_data(micro_batch, use_tgt_len_value=False) for micro_batch in micro_batches] + # size of micro_batches (num of training batch) would be: + # per_device_prompt_batch_size * num_return_sequences // per_device_train_batch_size + # micro_batches = [self.post_rollout(**micro_batch) for micro_batch in micro_batches] return micro_batches @paddle.no_grad() - def rollout(self, prompt_only_batch: PromptOnlyBatch) -> List[Dict[str, Any]]: + def generate(self, prompt_only_batch: Dict) -> List[Dict[str, Any]]: """Rollout a batch of experiences.""" input_ids = prompt_only_batch["input_ids"] - # NOTE: generation output of paddlenlp do not contain prompt, we should - # change sequences here. 
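+        # With trunc_input=False in generation_config, the generated sequences
+        # already include the prompt, so prompt and response do not need to be
+        # concatenated here.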
+ attention_mask = prompt_only_batch["attention_mask"] + + self.timers and self.timers("actor-model-generate").start() sequences = self.actor_model.generate( input_ids=input_ids, - attention_mask=prompt_only_batch["attention_mask"], + attention_mask=attention_mask, + position_ids=prompt_only_batch["position_ids"] + if "position_ids" in prompt_only_batch + else make_position_ids(attention_mask), generation_config=self.generation_config, synced_gpus=ShardingOption.FULL_SHARD in self.policy_trainer.args.sharding, )[0] + + self.timers and self.timers("actor-model-generate").stop() sequences = sequences.reshape([input_ids.shape[0], self.args.num_return_sequences, -1]).transpose([1, 0, 2]) + # prompt, sequence, attention_mask return [ - self.post_rollout( - input_ids, - seq, - attention_mask=paddle.logical_and( - seq != self.tokenizer.pad_token_id, - seq != self.tokenizer.unk_token_id, + { + "prompt": input_ids, + "input_ids": seq, # "sequence": + "attention_mask": make_attention_mask( + seq, + pad_id=self.tokenizer.pad_token_id, + unk_id=self.tokenizer.unk_token_id, + causal_mask=False, ), - ) + # "sequence_mask": make_attention_mask( + # seq, + # pad_id=self.tokenizer.pad_token_id, + # unk_id=self.tokenizer.unk_token_id, + # causal_mask=False, + # ), + } for seq in sequences ] + @paddle.no_grad() + def rollout_logprob( + self, input_ids: paddle.Tensor, attention_mask: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs + ) -> Dict[str, paddle.Tensor]: + # pipe model outputs a logits tensor with LMHead, while non-pipe model + # outputs a tuple with logits tensor as the only one element. + logits = self.actor_model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + # return_dict=True, + ) # .logits + if not isinstance(logits, paddle.Tensor): + logits = logits[0] + ref_logits = self.reference_model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + # return_dict=True, + ) # .logits + if not isinstance(ref_logits, paddle.Tensor): + ref_logits = ref_logits[0] + log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) + ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], input_ids[:, 1:]) + return {"log_probs": log_probs, "ref_log_probs": ref_log_probs} + + @paddle.no_grad() + def rollout_reward_value( + self, input_ids: paddle.Tensor, attention_mask: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs + ) -> Dict[str, paddle.Tensor]: + if self.reward_tokenizer is not self.tokenizer: + # right padding + reward_tokenize_output = batch_retokenize( + input_ids, + src_tokenizer=self.tokenizer, + dest_tokenizer=self.reward_tokenizer, + skip_special_tokens=True, + ) + reward_input_ids = reward_tokenize_output["input_ids"] + reward_attention_mask = make_attention_mask( + reward_input_ids, + pad_id=self.reward_tokenizer.pad_token_id, + unk_id=self.reward_tokenizer.unk_token_id, + causal_mask=False, + ) + reward_position_ids = make_position_ids(reward_attention_mask) + else: + # for text in self.tokenizer.batch_decode(input_ids, skip_special_tokens=False): + # print(text) + reward_input_ids = input_ids + reward_attention_mask = attention_mask + reward_position_ids = position_ids + reward_score = self.reward_model( + reward_input_ids, + attention_mask=reward_attention_mask, + position_ids=reward_position_ids, + # return_dict=True, + )[ + 1 + ] # .end_scores + + reward_value = self.reward_critic_model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + # return_dict=True, + )[ + 0 + ] # 
.scores + reward_score = reward_score.squeeze(axis=-1) + reward_value = reward_value.squeeze(axis=-1) + + reward_value = reward_value[:, :-1] + return {"rewards": reward_score, "reward_values": reward_value} + @paddle.no_grad() def post_rollout( self, @@ -1551,6 +1650,7 @@ def post_rollout( attention_mask: paddle.Tensor, ) -> Dict[str, Any]: if self.reward_tokenizer is not self.tokenizer: + # right padding reward_tokenize_output = batch_retokenize( sequence, src_tokenizer=self.tokenizer, @@ -1560,34 +1660,143 @@ def post_rollout( reward_seq = reward_tokenize_output["input_ids"] reward_attention_mask = reward_tokenize_output["attention_mask"] else: + # actor_model_in_use gen + # for text in self.tokenizer.batch_decode(sequence, skip_special_tokens=True): + # print(text) reward_seq = sequence reward_attention_mask = attention_mask - + # position_ids is necessary for non-right padding + # If using right padding source + left padding target, make padding positions + # in source be 0, since reward model use position_ids plus with padding size + # (number of 0s) in source to calculate end offsets. + position_ids = make_position_ids(attention_mask) + + # pipe model outputs a logits tensor with LMHead, while non-pipe model + # outputs a tuple with logits tensor as the only one element. + self.timers and self.timers("actor-model-logit").start() logits = self.actor_model( sequence, attention_mask=attention_mask, - return_dict=True, - ).logits - ref_logits = self.reference_model(sequence, attention_mask=attention_mask, return_dict=True).logits + position_ids=position_ids, + # return_dict=True, + ) # .logits + self.timers and self.timers("actor-model-logit").stop() + if not isinstance(logits, paddle.Tensor): + logits = logits[0] + self.timers and self.timers("reference-model-logit").start() + ref_logits = self.reference_model( + sequence, + attention_mask=attention_mask, + position_ids=position_ids, + # return_dict=True, + ) # .logits + self.timers and self.timers("reference-model-logit").stop() + if not isinstance(ref_logits, paddle.Tensor): + ref_logits = ref_logits[0] + + self.timers and self.timers("reward-model-score").start() + reward_score = self.reward_model( + reward_seq, + attention_mask=reward_attention_mask, + position_ids=position_ids, + # return_dict=True, + )[ + 1 + ] # .end_scores - reward_score = self.reward_model(reward_seq, attention_mask=reward_attention_mask, return_dict=True).end_scores reward_value = self.reward_critic_model( sequence, attention_mask=attention_mask, - return_dict=True, - ).scores - + position_ids=position_ids, + # return_dict=True, + )[ + 0 + ] # .scores reward_score = reward_score.squeeze(axis=-1) - reward_value = reward_value.squeeze(axis=-1)[:, :-1] + reward_value = reward_value.squeeze(axis=-1) + self.timers and self.timers("reward-model-score").stop() + reward_value = reward_value[:, :-1] log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:]) ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:]) - return { + rollout_data = { "prompt": prompt, - "log_probs": log_probs, - "ref_log_probs": ref_log_probs, - "rewards": reward_score, - "reward_values": reward_value, "input_ids": sequence, + "position_ids": position_ids, "attention_mask": attention_mask, + "rewards": reward_score, + "reward_values": reward_value, + "log_probs": log_probs, + "ref_log_probs": ref_log_probs, } + rollout_data = self.normalize_data(rollout_data, use_tgt_len_value=False) + return rollout_data + + @paddle.no_grad() + def normalize_data( + self, + 
rl_batch: Dict[str, paddle.Tensor], + use_tgt_len_value: bool = False, + ) -> Dict[str, Any]: + """ + data dispatch comm among devices needs padding, while the lengths of + all data fields are different and related, and it's hard to pad. + """ + prompt = rl_batch["prompt"] # length: src + attention_mask = rl_batch["attention_mask"] # length: src + tgt + if len(attention_mask.shape) == 4: + # use padding mask instead of causal mask + attention_mask = rl_batch["sequence_mask"] # length: src + tgt + old_log_probs = rl_batch["log_probs"] # length: src + tgt -1 + ref_log_probs = rl_batch["ref_log_probs"] # length: src + tgt -1 + rewards = rl_batch["rewards"] # length: 1 + old_reward_values = rl_batch["reward_values"] # length: src + tgt -1 + + # Beaver uses label data with target length, while we do not slice from + # inputs and use label data with target length: + # 1. Sometimes we cannot use label data with target length, mostly because + # it is hard to pad acorss batches. Think in some cases one batch might + # have the longest prompt+target length but the shortest target lengh, which + # might cause mismatch between inputs with prompt+target length and labels + # with target length. Padding acorss batches is needed in PP and data comm. + # 2. Additionally, when using flash_attn with casual mask and right padding + # we cannot use label data with target length. + start = prompt.shape[-1] - 1 + # sequence_mask is for label masking, make source be masked out + # clone to avoid to change attention_mask + sequence_mask = attention_mask[:, 1:].clone() # length: src + tgt -1 + sequence_mask[:, :start] = False + if use_tgt_len_value: + ref_log_probs = ref_log_probs[:, start:] + old_log_probs = old_log_probs[:, start:] + old_reward_values = old_reward_values[:, start:] + sequence_mask = sequence_mask[:, start:] + old_rewards = self.add_kl_divergence_regularization( + None, # prompt, + old_log_probs, + ref_log_probs, + rewards, + sequence_mask, + ) # length: tgt if use_tgt_len_value src + tgt -1 + reward_advantages, reward_returns = self.get_advantages_and_returns( + old_reward_values, + old_rewards, + sequence_mask, + start=0 if use_tgt_len_value else start, + use_tgt_len_return=use_tgt_len_value, + ) # length: tgt if use_tgt_len_value src + tgt -1 + + rl_batch.update( + { + "log_probs": old_log_probs, + "reward_values": old_reward_values, + "reward_advantages": reward_advantages, + "reward_returns": reward_returns, + "sequence_mask": sequence_mask, + "ref_log_probs": ref_log_probs, + "rewards": rewards, + } + ) + # pop out to reduce data dispatch comm overhead + rl_batch.pop("prompt") + return rl_batch diff --git a/examples/RLHF/reward_config.json b/examples/RLHF/reward_config.json index d5dde7b6892c..bfe5b2d1b523 100644 --- a/examples/RLHF/reward_config.json +++ b/examples/RLHF/reward_config.json @@ -1,5 +1,5 @@ { - "model_name_or_path": "/root/paddlejob/workspace/guosheng/alpaca-7b-reproduced/", + "model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced", "train_datasets": "PKU-SafeRLHF-30K/train", "eval_datasets": "PKU-SafeRLHF-30K/test", "output_dir": "/root/paddlejob/workspace/guosheng/checkpoints/llama_sft_ckpts-test", diff --git a/examples/RLHF/reward_main.py b/examples/RLHF/reward_main.py index f52415452b2e..6bcd4485ccfa 100644 --- a/examples/RLHF/reward_main.py +++ b/examples/RLHF/reward_main.py @@ -32,6 +32,14 @@ from paddlenlp.transformers import AutoConfig, AutoTokenizer, LlamaTokenizer from paddlenlp.utils.log import logger +# launch would unset http_proxy +# export 
https_proxy=http://172.19.57.45:3128 +# os.environ["http_proxy"] = "http://172.19.57.45:3128" +# os.environ["https_proxy"] = "http://172.19.57.45:3128" +os.environ["http_proxy"] = "http://10.162.37.16:8128" +os.environ["https_proxy"] = "http://10.162.37.16:8128" + + @dataclass class TrainingArguments(TrainingArguments): loss_type: Literal["token-wise", "sequence-wise"] = field( diff --git a/examples/RLHF/reward_trainer.py b/examples/RLHF/reward_trainer.py index 42f11956d4fd..a542d55942a2 100644 --- a/examples/RLHF/reward_trainer.py +++ b/examples/RLHF/reward_trainer.py @@ -37,9 +37,9 @@ speed_metrics = trainer.speed_metrics -def patch_speed_metrics(split, start_time, num_samples=None, num_steps=None): +def patch_speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None): # split: interval, train, eval, test - result = speed_metrics(split, start_time, num_samples, num_steps) + result = speed_metrics(split, start_time, num_samples, num_steps, seq_length) if split not in ["train", "interval"]: return result # accuracy diff --git a/examples/RLHF/run.sh b/examples/RLHF/run.sh new file mode 100644 index 000000000000..a4bdca9e974e --- /dev/null +++ b/examples/RLHF/run.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +PYTHONPATH=../../ GLOG_minloglevel=2 python3.10 -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ppo_main.py ppo_config.json diff --git a/examples/RLHF/tests/run_model.py b/examples/RLHF/tests/run_model.py new file mode 100644 index 000000000000..19bb2176d5e2 --- /dev/null +++ b/examples/RLHF/tests/run_model.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
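The new tests/run_model.py (whose body follows) checks an exported evaluation model by splitting a batch of input_ids across a communication group, running each shard, and merging the per-rank logits back for comparison against a ground-truth model. A minimal single-process sketch of that split/merge round trip, with `FakeGroup` standing in for a paddle.distributed group and concatenation standing in for the all_gather-based merge:

```
import paddle

class FakeGroup:
    """Stand-in for a paddle.distributed communication group (illustration only)."""
    def __init__(self, nranks, rank):
        self.nranks = nranks
        self.rank = rank

def data_group_split(tensor, group):
    # same idea as comm_utils.data_group_split: each rank keeps only its shard
    return tensor.split(group.nranks)[group.rank]

full = paddle.arange(8 * 4, dtype="float32").reshape([8, 4])
shards = [data_group_split(full, FakeGroup(4, r)) for r in range(4)]
merged = paddle.concat(shards, axis=0)  # emulates data_group_merge's all_gather
assert bool(paddle.all(merged == full))
```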
+ +import copy +from dataclasses import dataclass, field + +import numpy +import paddle +from paddle.distributed import fleet +from ppo_trainer import Trainer, data_group_merge, data_group_split, group_rank_guard + +from paddlenlp.trainer import PdArgumentParser, TrainingArguments +from paddlenlp.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, +) + + +@dataclass +class ModelArgument: + model_name_or_path: str = field( + default=None, metadata={"help": "Build-in pretrained model name or the path to local model."} + ) + test_mode: str = field(default="export", metadata={"help": "export data_split or rank_guard."}) + + +def test_group_rank_guard(group): + @group_rank_guard(group=group, rank=0) + def func(): + tensor = paddle.randn([4, 64]) + return tensor + + t = func() + ret = [] + paddle.distributed.stream.all_gather(ret, t, group=group) + + for x in ret: + assert x._md5sum() == t._md5sum(), f"{x} {t}" + + +def main(): + # Arguments + parser = PdArgumentParser((ModelArgument, TrainingArguments)) + model_args, training_args = parser.parse_args_into_dataclasses() + + hcg = fleet.get_hybrid_communicate_group() + pp_group = hcg.get_pipe_parallel_group() + tp_group = hcg.get_model_parallel_group() + + if model_args.test_mode == "rank_guard": + test_group_rank_guard(tp_group) + return 0 + + model_config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + tensor_parallel_output=False, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + dtype="float32", + ) + + model_class = AutoModelForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = AutoModelForCausalLMPipe + + actor_model = model_class.from_pretrained( + model_args.model_name_or_path, + config=model_config, + ) + + if True: # test export_evaluate_model + # 随机初始化 + config = copy.deepcopy(model_config) + if training_args.pipeline_parallel_degree <= 1: + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + + actor_eval_model = AutoModelForCausalLM.from_config(config) + # ground truth模型 + actor_gt_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config) + + trainer = Trainer( + model=actor_model, + args=training_args, + ) + trainer.export_evaluate_model(actor_model, actor_eval_model) + + gp_state = actor_gt_model.state_dict() + export_state = actor_eval_model.state_dict() + + for k, v in gp_state.items(): + assert ( + v._md5sum() == export_state[k]._md5sum() + ), f"{k} groud_truth: {v.shape}, export: {export_state[k].shape}" + + split_group = tp_group + if training_args.pipeline_parallel_degree > 1: + split_group = pp_group + + input_ids = paddle.randint(low=1, high=50, shape=[8, 64]) + paddle.distributed.broadcast(input_ids, src=0) + + split_input_ids = data_group_split(input_ids, group=split_group) + ret = actor_eval_model(input_ids=split_input_ids, return_dict=True) + eval_loggits = data_group_merge(ret.logits, group=split_group) + + gt_ret = actor_gt_model(input_ids=input_ids, return_dict=True) + gt_loggits = gt_ret.logits + numpy.testing.assert_almost_equal(eval_loggits.numpy(), gt_loggits.numpy(), decimal=5) + + +if __name__ == "__main__": + main() diff --git a/examples/RLHF/tests/test_export.py b/examples/RLHF/tests/test_export.py new file mode 100644 index 000000000000..6254427e89cd --- /dev/null +++ b/examples/RLHF/tests/test_export.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from tests.parallel_launch import TestMultipleGpus + +environment_variables = { + "NCCL_ALGO": "Tree", + "NVIDIA_TF32_OVERRIDE": "0", + "NCCL_IB_TIMEOUT": "22", + "NCCL_DEBUG": "INFO", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + "Flags_mp_aysnc_allreduce": "1", + "Flags_skip_mp_c_identity": "1", + "FLAGS_shard_norm_align_dp": "0", + "FLAGS_shard_use_reduce": "1", + "test_ci_no_save_model": "1", +} + + +class TestExportEvalModel(TestMultipleGpus): + def setUp(self): + os.environ.update(environment_variables) + super().setUp() + + def test_pptp_to_tp(self): + config = { + "output_dir": "./tmp", + "model_name_or_path": "__internal_testing__/tiny-random-llama", + "tensor_parallel_degree": 2, + "pipeline_parallel_degree": 2, + } + scripts = "tests/run_model.py" + self.run_4gpu(scripts, **config) + + def test_tp_to_single(self): + config = { + "output_dir": "./tmp", + "model_name_or_path": "__internal_testing__/tiny-random-llama", + "tensor_parallel_degree": 2, + "pipeline_parallel_degree": 1, + } + scripts = "tests/run_model.py" + self.run_2gpu(scripts, **config) + + def test_group_rank_guard(self): + config = { + "output_dir": "./tmp", + "model_name_or_path": "__internal_testing__/tiny-random-llama", + "tensor_parallel_degree": 2, + "pipeline_parallel_degree": 1, + "test_mode": "rank_guard", + } + scripts = "tests/run_model.py" + self.run_2gpu(scripts, **config) diff --git a/examples/RLHF/trainer_utils.py b/examples/RLHF/trainer_utils.py new file mode 100644 index 000000000000..865d34cea653 --- /dev/null +++ b/examples/RLHF/trainer_utils.py @@ -0,0 +1,662 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
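The new trainer_utils.py (below) factors pieces of `Trainer.train` into standalone functions (`init_train_model_opt`, `init_train_state`, `init_train_log`, `full_training_step`) and assigns them back onto `Trainer` as methods at import time, so that the PPO trainer can reuse single training steps for multiple models. A minimal sketch of that attach-as-method pattern; `DummyTrainer` is a placeholder class for illustration, not part of the PR:

```
class DummyTrainer:
    """Placeholder for paddlenlp.trainer.Trainer (assumption for illustration)."""
    def __init__(self):
        self.state = {"global_step": 0}

def full_training_step(self, inputs, **kwargs):
    # a standalone function written against `self`, later attached as a method
    self.state["global_step"] += 1
    return kwargs

# patching: instances and subclasses pick the new method up automatically
DummyTrainer.full_training_step = full_training_step

trainer = DummyTrainer()
trainer.full_training_step({"input_ids": [1, 2, 3]})
assert trainer.state["global_step"] == 1
```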
+ +from __future__ import annotations + +import inspect +import os +import time +from contextlib import contextmanager +from typing import Dict + +import numpy as np +import paddle +import tqdm +from paddle.distributed import fleet +from paddle.io import DataLoader + +from paddlenlp.generation.utils import GenerationMixin +from paddlenlp.trainer.trainer import ( + TRAINER_STATE_NAME, + HybridParallelOptimizer, + NlpDistributedBatchSampler, + ShardingOption, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, + _obtain_optimizer_parameters_list, + distributed_file, + distributed_isfile, + fused_allreduce_gradients, + logger, + reshard_util, + split_inputs_sequence_dim, +) +from paddlenlp.transformers import BatchEncoding, PretrainedModel, PretrainedTokenizer +from paddlenlp.transformers.configuration_utils import PretrainedConfig +from paddlenlp.transformers.model_outputs import ModelOutput +from paddlenlp.transformers.tokenizer_utils_base import ( + PaddingStrategy, + TruncationStrategy, +) + + +# ########## patches for Trianer ########## +def init_train_model_opt( + self: Trainer, max_steps: int, resume_from_checkpoint: bool = False, clear_master_weight: bool = False +) -> PretrainedModel: + # Copy of model/optimizer init and resuming related code in `Trainer.train`. + # NOTE: this `_load_from_checkpoint` is indeed to load model states in the + # following elif-else branches, though they are apart away in `Trainer.train`. + if not self.args.should_load_sharding_stage1_model: + self._load_from_checkpoint(resume_from_checkpoint) + + # delay_optimizer_creation = ( + # self.sharding is not None + # and ShardingOption.SHARD_OP in self.args.sharding + # ) + delay_optimizer_creation = False + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if self.args.should_load_sharding_stage1_model: + model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) + elif self.args.should_save_sharding_stage1_model: + # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. + # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. 
+ model = self._wrap_model(self.model_wrapped) + if self.sharding_io is not None: + assert delay_optimizer_creation is False, "delay_optimizer_creation should be False" + # the self.optimizer should be wrapped and it is done in _wrap_model + self.sharding_io.set_optimizer(self.optimizer) + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + self._load_optimizer_and_scheduler(resume_from_checkpoint) + else: + model = self._wrap_model(self.model_wrapped) + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + if ShardingOption.FULL_SHARD in self.args.sharding and clear_master_weight: + # for inference model to use Trainer sharding stage3, clear master_weight + # which is created in GroupShardedStage3.__init__ + self.optimizer._master_weights = None + + if self.args.device == "npu" and self.args.flatten_param_grads: + from .plugins.npu_plugin import npu_accelerate_plugin + + npu_accelerate_plugin(self.optimizer) + + return model + + +def init_train_state( + self: Trainer, + resume_from_checkpoint: bool, + train_dataloader: DataLoader, + max_steps: int, + num_train_epochs: int, + num_update_steps_per_epoch: int, +): + args = self.args + + self.state = TrainerState() + self.state.epoch = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and distributed_isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json( + distributed_file(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + ) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " + "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " + "flag to your launch command, but you will resume the training on data already seen by your model." 
+ ) + if self.is_local_process_zero() and not args.disable_tqdm: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + if not args.ignore_data_skip: + if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( + train_dataloader.batch_sampler, NlpDistributedBatchSampler + ): + consumed_samples = ( + self.state.global_step + * args.train_batch_size + * args.gradient_accumulation_steps + * args.dataset_world_size + ) + train_dataloader.batch_sampler.set_epoch(consumed_samples=consumed_samples) + logger.info(f"Set DistributedBatchSampler consumed_samples to {consumed_samples}") + + self.state.max_steps = int(max_steps) + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + return epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar + + +def init_train_log( + self: Trainer, + num_examples: int, + num_train_epochs: int, + total_train_batch_size: int, + max_steps: int, + num_train_samples: int, + model: PretrainedModel, +): + args = self.args + + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Total num train samples = {num_train_samples:,}") + # per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) + # TODO: Temporary fix since Tensor.numel() not supported in distributed mode + per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient) + logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + if self.args.use_hybrid_parallel: + # todo fix for pipeline_parallel_degree + parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) + if parts_num > 1: + all_reduce_dtype = "int64" + if paddle.get_device().split(":")[0] in ["npu", "xpu"]: + # TODO(duanyanhui): fix when NPU all_reduce supports int64 + all_reduce_dtype = "float32" + trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) + paddle.distributed.all_reduce(trainable_numel_tensor) + trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited + # so, the trainable numel is a little bigger than real. + logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + + +def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs): + """ + Just a copy of single training step complete code in Trainer.train while loop + which including forward+backward+step, while wraps the inputs and outputs to + make the complicated copied code no need to change. Maybe a better way is to + add fine-grained methods including these steps to Trainer which is similar to + DeepSpeed engine. 
+ """ + + # TODO(guosheng): step, steps_trained_in_current_epoch and steps_trained_progress_bar + # should use reference since they would be overwrite. + # for state update + epoch = kwargs.get("epoch", 0) + step = kwargs.get("step", 0) + steps_in_epoch = kwargs.get("steps_in_epoch", 0) + step_control = kwargs.get("step_control", 0) + # for step and progress update when resuming data + train_dataloader = kwargs.get("train_dataloader", None) + resume_from_checkpoint = kwargs.get("resume_from_checkpoint", None) + steps_trained_in_current_epoch = kwargs.get("steps_trained_in_current_epoch", 0) + steps_trained_progress_bar = kwargs.get("steps_trained_progress_bar", None) + # for eval output ignore to gather + ignore_keys_for_eval = kwargs.get("ignore_keys_for_eval", None) + tr_loss = kwargs.get("tr_loss", 0.0) + model = kwargs.get("model", self.model_wrapped) + # needed in _maybe_log_save_evaluate + self._globalstep_last_logged = getattr(self, "_globalstep_last_logged", 0) + self._globalstep_last_start_time = getattr(self, "_globalstep_last_start_time", time.time()) + + args = self.args + + if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: + inputs = split_inputs_sequence_dim(inputs) + # self.timers and self.timers("read-data").stop() + os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) + self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) + + # Skip past any already trained steps if resuming training + # for paddlenlp.utils.batch_sampler.DistributedBatchSampler + # We use consumed_samples to reset the status + if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( + train_dataloader.batch_sampler, NlpDistributedBatchSampler + ): + if step == 0: + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(steps_trained_in_current_epoch) + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + self._load_rng_state(resume_from_checkpoint) + step += steps_trained_in_current_epoch + elif steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + # continue + final_local_vars = locals() + for k in kwargs.keys(): + if k in final_local_vars: + kwargs[k] = final_local_vars[k] + return kwargs + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step_control % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + self.timers and self.timers("forward-backward").start() + + dp_enabled = self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1 + forbidden_no_sync = False + # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API + # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync + if self.args.use_hybrid_parallel: + forbidden_no_sync = True + + availiable_no_sync = dp_enabled and not forbidden_no_sync + + is_no_sync = ( + ((step_control + 1) % args.gradient_accumulation_steps != 0) + and availiable_no_sync + and args._no_sync_in_gradient_accumulation + ) or (args.recompute and availiable_no_sync) + # sharding + # stage1. the same as ddp + # stage2. 
manualy collect gradient on dp group + + dp_master_grad = self.args.world_size > 1 and self.args.amp_master_grad and not self.args.use_hybrid_parallel + if dp_master_grad: + is_no_sync = True + + if is_no_sync: + # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. + with model.no_sync(): + tr_loss_step = self.training_step(model, inputs) + else: + tr_loss_step = self.training_step(model, inputs) + + tr_loss += tr_loss_step + + if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + ): + if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): + tr_loss /= self.args.gradient_accumulation_steps + + self.timers and self.timers("forward-backward").stop() + # Maunally collect gradients + # Case 1: Use recompute and dp + # Case 2: Hack dp with master_grad + # Case 3: Pipeline or sharding overlap + # local_rank != -1 don't means dp in networks. + self.timers and self.timers("all-reduce").start() + + # Case 1: Use recompute and dp / sharding stage1, + # manualy collect gradient for dp. + if args.recompute and availiable_no_sync: + fused_allreduce_gradients(list(model.parameters()), None) + + # Case 2: hack dp with master_grad + if dp_master_grad and not (args.recompute and availiable_no_sync): + fused_allreduce_gradients(list(model.parameters()), None) + + # Pipeline parallel mode, handle gradient reduce here to overlap + pipeline_parallel_config = ( + set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set() + ) + enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config + enable_release_grads = "enable_release_grads" in pipeline_parallel_config + + # Case 3: Pipeline parallel mode, overlap with dp + if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: + parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) + + if not enable_dp_comm_overlap: + if self.optimizer._sharding_enable: + assert reshard_util.is_sharding_opt(self.optimizer) + self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) + + if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): + fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) + + self.timers and self.timers("all-reduce").stop() + self.timers and self.timers("optimizer-step").start() + + if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss(): + for p in model._layers.parameters(): + with paddle.no_grad(): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) + elif p.grad is not None: + p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) + + # Optimizer step + self.callback_handler.on_optimizer_begin( + args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + ) + optimizer_was_run = True + if self.do_grad_scaling: + scale_before = paddle.assign(self.scaler._scale) + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler._scale + # Compatible with paddlepaddle 2.6.0 using typo word. 
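For context on the accumulate-then-scale logic just above (where `main_grad`/`grad` is divided by `gradient_accumulation_steps` once before the optimizer step), here is a minimal single-device sketch of the same idea using plain SGD; it is an illustration only, not the hybrid-parallel path used by the trainer:

```
import paddle

accum_steps = 4
model = paddle.nn.Linear(8, 1)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

for step in range(accum_steps):
    x = paddle.randn([2, 8])
    loss = model(x).mean()
    loss.backward()  # gradients accumulate across micro-batches

with paddle.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.grad.scale_(1.0 / accum_steps)  # scale once instead of per micro-batch
opt.step()
opt.clear_grad()
```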
+ if hasattr(self.scaler, "_cache_founf_inf"): + optimizer_was_run = not self.scaler._cache_founf_inf + else: + optimizer_was_run = not self.scaler._cache_found_inf + if not optimizer_was_run: + scale_before_value = scale_before.cpu().numpy() + scale_after_value = scale_after.cpu().numpy() + logger.warning( + f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" + ) + elif isinstance(self.optimizer, HybridParallelOptimizer): + self.optimizer._step(parameters_list) + else: + self.optimizer.step() + + self.timers and self.timers("optimizer-step").stop() + + if optimizer_was_run: + self.lr_scheduler.step() + + if enable_release_grads and args.pipeline_parallel_degree > 1: + self.optimizer.clear_grad(set_to_zero=False) + for _, buffers in model._chunk_2_comm_buffers.items(): + for buffer in buffers: + buffer._clear_grad_storage() + else: + self.optimizer.clear_grad(set_to_zero=False) + + self.callback_handler.on_optimizer_end( + args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + ) + + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) + self._print_timer() + step_control = 0 + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + step_control += 1 + + if self.control.should_epoch_stop or self.control.should_training_stop: + # break + final_local_vars = locals() + for k in kwargs.keys(): + if k in final_local_vars: + kwargs[k] = final_local_vars[k] + return kwargs + # self.timers and self.timers("read-data").start() + + final_local_vars = locals() + for k in kwargs.keys(): + if k in final_local_vars: + kwargs[k] = final_local_vars[k] + return kwargs + + +Trainer.init_train_model_opt = init_train_model_opt +Trainer.init_train_log = init_train_log +Trainer.init_train_state = init_train_state +Trainer.full_training_step = full_training_step +# ########## patches for Trianer ########## + + +class MuteDefaultFlowCallback(TrainerCallback): + """ + Add this callback can cencel logging/evaluation/saving by DefaultFlowCallback. + Use this when having multi trainer. + """ + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + control.should_save = False + control.should_evaluate = False + control.should_log = False + return control + + +@contextmanager +def guard_set_args(args, arg_name_values): + for k, v in arg_name_values.items(): + old_value = getattr(args, k, None) + setattr(args, k, v) + arg_name_values[k] = old_value + yield + for k, v in arg_name_values.items(): + old_value = getattr(args, k) + setattr(args, k, v) + arg_name_values[k] = old_value + + +class PipeEvalModel(GenerationMixin): + """ + Wrapper for PipelineParallel to do evaluate and generate. Currently only + support . 
+ """ + + def __init__(self, trainer: Trainer): + eval_model = getattr(trainer, "_inner_eval_model", None) + self.model: fleet.model.PipelineParallel = trainer.model_wrapped if eval_model is None else eval_model + self.config: PretrainedConfig = trainer.model.config + self._is_gen = False + self.update_model_kwargs_for_generation = ( + self.model._layers._non_pipe_model_class.update_model_kwargs_for_generation + ) + + @property + def pp_group(self): + return self.model.pp_group + + def eval(self): + self.model.eval() + + def train(self): + self.model.train() + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.model, name) + + def _broadcast_outputs(self, outputs): + # outputs is PipelineParallel.eval_batch which is a list of batches. + out = [] + outputs = (outputs,) if isinstance(outputs, paddle.Tensor) else outputs + for tensors in outputs: + if not self.model.is_pipeline_last_stage(): + tensor = tensors if isinstance(tensors, paddle.Tensor) else tensors[0] + head_out_meta = ( + (self.model._layers.head_out_meta,) + if isinstance(self.model._layers.head_out_meta, paddle.static.InputSpec) + else self.model._layers.head_out_meta + ) + tensors = tuple( + paddle.empty( + shape=[ + tensor.shape[i] if (meta.shape[i] is None or meta.shape[i] < 0) else meta.shape[i] + for i in range(len(meta.shape)) + ], + dtype=tensor.dtype if meta.dtype is None else meta.dtype, + ) + for meta in head_out_meta + ) + else: + # Currently use tuple instead of ModelOutput and require the + # caller use the return result as tuple. + tensors = ( + (tensors,) + if isinstance(tensors, paddle.Tensor) + else tensors.to_tuple() + if isinstance(tensors, ModelOutput) + else tensors + ) + + # use map_structure seems hung + for tensor in tensors: + paddle.distributed.broadcast(tensor, src=self.model.pp_group.ranks[-1], group=self.model.pp_group) + out.append(tensors[0] if len(tensors) == 1 else tensors) + return out[0] if len(out) == 1 else out + + def __call__(self, *args, **kwargs): + model = self.model + assert self.model.training is False + if self._is_gen: + # inputs by `prepare_inputs_for_generation` is a dict with following keys: + # "input_ids", "position_ids", "past_key_values", "use_cache", "attention_mask" + # NOTE: 1. cache/past_key_values should be passed across decoding steps + # by using as model attr rather than input args to reduce comm overhead. + # Also, pipe model defined for training not support this cache input. + # 2. ignore use_cache since _check_data_vaild requires tensor if not None. + # 3. attention_mask can reuse _prepare_decoder_attention_mask in LlamaEmbeddingPipe. + # 4. position_ids pass through _prepare_pipeline_inputs_func and PipeLayer. + inputs, labels = model._prepare_pipeline_inputs_func(*args, **kwargs) + # currently, set accumulate_steps to 1 to avoid multi-batch eval/gen + with guard_set_args(model, {"_compute_loss": False, "accumulate_steps": 1}): + outputs = model.eval_batch([inputs, labels], compute_loss=False) + # TODO(guosheng): Broadcasted logits are used to get next_scores, remove + # it to reduce comm overhead. Also note that we still need broadcast + # next_tokens though logits are broadcasted since pp ranks' seeds differs. + # Currently, just slice the last token to reduce comm overhead. 
+ outputs = [ + micro_batch_output[:, -1, :].unsqueeze(1) + if isinstance(micro_batch_output, paddle.Tensor) + else micro_batch_output[0][:, -1, :].unsqueeze(1) + for micro_batch_output in outputs + ] + outputs = self._broadcast_outputs(outputs) + else: + # use _prepare_pipeline_inputs_func to convert pipeline inputs + inputs, labels = model._prepare_pipeline_inputs_func(*args, **kwargs) + # NOTE(guosheng): bug seems exist. pp.eval_batch(compute_loss=False) + # will set pp._compute_loss to False and would not set it back. Thus + # hack here to set it back. + with guard_set_args(model, {"_compute_loss": False, "accumulate_steps": 1}): + outputs = model.eval_batch([inputs, labels], compute_loss=False) + outputs = self._broadcast_outputs(outputs) + return outputs + + def generate(self, *args, **kwargs): + self._is_gen = True + # patch DecoderLayerPipe to use cache, DecoderLayerPipe is subclass of + # DecoderLayer, and would call super().forward + ori_decoder_layer_forward = self.model._layers._non_pipe_decoder_layer_class.forward + + def decoder_layer_forward(layer_self, *args, **kwargs): + kwargs.update({"use_cache": True, "past_key_value": getattr(layer_self, "_cache", None)}) + outputs = ori_decoder_layer_forward(layer_self, *args, **kwargs) + output = outputs[0] + layer_self._cache = outputs[1] + self._has_cache = True + return output + + with guard_set_args(self.model._layers._non_pipe_decoder_layer_class, {"forward": decoder_layer_forward}): + outputs = super().generate(*args, **kwargs) + self._is_gen = False + # clear cache of decoder layers, sublayers is incursive thus suitable + # to both 1F1B and interleave + for layer in self.model._layers.sublayers(): + if isinstance(layer, self.model._layers._non_pipe_decoder_layer_class): + layer._cache = None + self._has_cache = False + return outputs + + def prepare_inputs_for_generation(self, *args, **kwargs): + arg_bind = inspect.signature(self.model._layers._non_pipe_model_class.prepare_inputs_for_generation).bind( + *((self,) + args), **kwargs + ) + arg_bind.apply_defaults() + arg_dict = arg_bind.arguments + last_arg_name, last_arg_value = arg_dict.popitem() + if arg_bind.signature.parameters[last_arg_name].kind == inspect.Parameter.VAR_KEYWORD: + arg_dict.update(last_arg_value) + else: + arg_dict[last_arg_name] = last_arg_value + arg_dict.pop("self") + past_key_values = arg_dict.get("past_key_values", None) + # prepare_inputs_for_generation use past_key_values to discrimate prefill + # or decode and slice inputs accordingly. 
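The `prepare_inputs_for_generation` override above relies on `inspect.signature(...).bind` to fold positional and keyword arguments into a single dict before delegating to the non-pipeline model class. A standalone illustration of that binding trick; `fake_prepare` is a made-up function, not the real API:

```
import inspect

def fake_prepare(self, input_ids, use_cache=False, past_key_values=None, **kwargs):
    return {"input_ids": input_ids, "use_cache": use_cache,
            "past_key_values": past_key_values, **kwargs}

bound = inspect.signature(fake_prepare).bind("self", [[1, 2, 3]], use_cache=True, attention_mask=None)
bound.apply_defaults()            # fill in defaults such as past_key_values=None
args = dict(bound.arguments)
kwargs = args.pop("kwargs", {})   # flatten the VAR_KEYWORD bucket
args.update(kwargs)
args.pop("self")
print(args)  # {'input_ids': [[1, 2, 3]], 'use_cache': True, 'past_key_values': None, 'attention_mask': None}
```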
+ if getattr(self, "_has_cache", False): + arg_dict.update({"past_key_values": True}) + model_inputs = self.model._layers._non_pipe_model_class.prepare_inputs_for_generation(self, **arg_dict) + model_inputs.update({"past_key_values": past_key_values}) + return model_inputs + + +def is_same_tokenizer( + tokenizer: PretrainedTokenizer, + other_tokenizer: PretrainedTokenizer, +) -> bool: + """Check if two tokenizers are the same.""" + return tokenizer is other_tokenizer or ( + tokenizer.__class__ == other_tokenizer.__class__ and tokenizer.get_vocab() == other_tokenizer.get_vocab() + ) + + +def batch_retokenize( + input_ids: paddle.Tensor, + src_tokenizer: PretrainedTokenizer, + dest_tokenizer: PretrainedTokenizer, + *, + padding: bool | str | PaddingStrategy = PaddingStrategy.LONGEST, + truncation: bool | str | TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + skip_special_tokens: bool = True, +) -> BatchEncoding: + """Re-tokenize a batch of input ids from one tokenizer to another.""" + output = dest_tokenizer( + [ + text + dest_tokenizer.eos_token + for text in src_tokenizer.batch_decode( + input_ids, + skip_special_tokens=skip_special_tokens, + ) + ], + padding=padding, + truncation=truncation, + return_tensors="pd", + ) + return output diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index f22eecb15d19..16f1ed76c23d 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -126,7 +126,7 @@ def __init__(self, config: LlamaConfig): self.quant_type ) - if config.tensor_parallel_degree > 1: + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding( self.vocab_size, self.hidden_size, diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index 0391647ab65e..aa2958fc26a1 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -1221,10 +1221,19 @@ def sample( try: hcg = fleet.get_hybrid_communicate_group() group = hcg.get_model_parallel_group() - src = group.get_model_parallel_group_src_rank() + src = hcg.get_model_parallel_group_src_rank() except: group, src = None, 0 paddle.distributed.broadcast(next_tokens, src=src, group=group) + # config does not include pipeline_parallel_degree, and pipeline parallel + # uses trainer.model_wrapped to run in both train and predict mode + # which has pp_group as a attribute + # TODO(guosheng): only let the last stage of pipeline to do softmax + # and sampling, and then broadcast to avoid broadcast logits. 
+ if getattr(self, "pp_group", None) is not None: + paddle.distributed.broadcast( + next_tokens, src=self.pp_group.ranks[0], group=self.pp_group # use rank 0 for same seed to check + ) next_scores = paddle.index_sample(origin_probs, next_tokens) diff --git a/paddlenlp/trainer/plugins/timer.py b/paddlenlp/trainer/plugins/timer.py index 6413a16eb1b4..50bc61d96731 100644 --- a/paddlenlp/trainer/plugins/timer.py +++ b/paddlenlp/trainer/plugins/timer.py @@ -116,6 +116,7 @@ def log(self, names, normalizer=1.0, reset=True): assert normalizer > 0.0 # string = "time (ms) / rate" string = "time (ms)" + names = sorted(list(names)) time_dict = {} for name in names: diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py index 25f593f71e35..3e7693aaad60 100644 --- a/paddlenlp/trainer/utils/helper.py +++ b/paddlenlp/trainer/utils/helper.py @@ -25,6 +25,7 @@ from paddle.distributed import fleet from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_broadcast_tensor_with_empty # noqa: F401 from paddlenlp.utils.nested import ( nested_broadcast_tensor, nested_empty_tensor, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 8f2dd1c36415..637b2b3e7df7 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -362,8 +362,9 @@ def forward(self, hidden_states): if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): - hidden_states = hidden_states.astype("float32") - variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states.astype("float32") + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states else: hidden_states = hidden_states.astype("float32") @@ -1627,7 +1628,11 @@ def __init__(self, config): super(LlamaPretrainingCriterion, self).__init__() self.ignore_index = getattr(config, "ignore_index", -100) self.config = config - self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + self.enable_parallel_cross_entropy = ( + config.tensor_parallel_degree > 1 + and config.vocab_size % config.tensor_parallel_degree == 0 + and config.tensor_parallel_output + ) if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index a00d8fc01f76..edb262681597 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -96,7 +96,7 @@ def __init__(self, config): self.config = config self.sequence_parallel = config.sequence_parallel self.hidden_size = config.hidden_size - if config.tensor_parallel_degree > 1: + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -183,15 +183,25 @@ def forward(self, args): if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: if attention_mask is not None or alibi is not None: hidden_states = recompute( - super().forward, hidden_states, attention_mask=attention_mask, alibi=alibi, use_reentrant=False + super().forward, + hidden_states, + position_ids=position_ids, + 
attention_mask=attention_mask, + alibi=alibi, + use_reentrant=False, ) else: # for pretrain hidden_states = recompute( - super().forward, hidden_states, use_reentrant=self.config.recompute_use_reentrant + super().forward, + hidden_states, + position_ids=position_ids, + use_reentrant=self.config.recompute_use_reentrant, ) else: - hidden_states = super().forward(hidden_states, attention_mask=attention_mask, alibi=alibi) + hidden_states = super().forward( + hidden_states, position_ids=position_ids, attention_mask=attention_mask, alibi=alibi + ) return return_args(hidden_states, attention_mask, position_ids, alibi) @@ -252,7 +262,7 @@ def get_hcg(): f"llama.layers.{i}", ) self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama") - self.add_sequential_layer(LayerDesc(LlamaLMHead, config=config), "lm_head") + self.add_head(config) recompute_interval = 0 @@ -263,7 +273,7 @@ def get_hcg(): PipelineLayer.__init__( self, layers=self.get_sequential_layers(), - loss_fn=LlamaPretrainingCriterion(config), + loss_fn=self.get_loss_fn(config), topology=get_hcg().topology(), seg_method=seg_method, recompute_interval=recompute_interval, @@ -278,3 +288,9 @@ def get_hcg(): self.apply(self._init_weights) # DON'T init PipelinePretrainedModel # PipelinePretrainedModel.__init__(self.super(), config=config) + + def add_head(self, config): + self.add_sequential_layer(LayerDesc(LlamaLMHead, config=config), "lm_head") + + def get_loss_fn(self, config): + return LlamaPretrainingCriterion(config) diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index f2e79a47c565..f6032f0cf1f8 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -2589,7 +2589,8 @@ def _set_pipeline_name_mapping(self, mappings=None): idx = name_splited[0] # for normal pp layer if idx.isdigit(): - single_name = [prefixes[idx]] + # allow empty prefix + single_name = [] if prefixes[idx] == "" else [prefixes[idx]] single_name.extend(name_splited[1:]) elif idx == "shared_layers": single_name = [self.get_shardlayer_prefix(name_splited)] diff --git a/paddlenlp/utils/nested.py b/paddlenlp/utils/nested.py index 27942b8cb256..4e800231843c 100644 --- a/paddlenlp/utils/nested.py +++ b/paddlenlp/utils/nested.py @@ -17,6 +17,8 @@ import paddle +from paddlenlp.utils.log import logger + TensorHolder = collections.namedtuple("TensorHolder", ["shape", "dtype", "name"]) @@ -63,6 +65,39 @@ def nested_broadcast_tensor(tensor, src=0, group=None): return tensor +def nested_broadcast_tensor_with_empty(tensor, src=0, group=None): + # src should src rank in the group, not global rank. + process_rank = paddle.distributed.get_rank() + + if group is not None: + src_rank = group.ranks[src] + if process_rank == src_rank: + if tensor is None: + logger.warning( + f"Your local rank {paddle.distributed.get_rank()} must have a state_dict. dp_rank:{process_rank}, src_rank:{src_rank}" + ) + fake_tensor = [nested_reduce_tensor(tensor)] + else: + if tensor is not None: + logger.warning( + f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict. 
dp_rank:{process_rank}, src_rank:{src_rank}" + ) + fake_tensor = [None] + + paddle.distributed.broadcast_object_list( + fake_tensor, + src=src_rank, + group=group, + ) + fake_tensor = fake_tensor[0] + + if process_rank != src_rank: + tensor = nested_empty_tensor(fake_tensor) + + tensor = nested_broadcast_tensor(tensor, src=src_rank, group=group) + return tensor + + def nested_copy(inputs): if isinstance(inputs, dict): outputs = {}