Skip to content

Commit bf6d4e7

Browse files
guoshengCSZHUIgongel
authored
Add Pipeline Parallel for PPO training and support generation with InferenceModel (#7953)
* Add Pipeline Parallel for PPO training. * Move new_ppo_trainer.py to ppo_trainer.py * Fix padding among batches of accumulation steps in _prepare_pipeline_inputs_func. * Fix hcg using in TP generation * Try to support generation in PP. And allow extra training args passed from main from_pratrined. * Support PP generation. * Fix PP eval by unify prediction_step * Fix reward value showing error cased by BF16 dtype when eval * fix all * Make non-PipelineParallel models use the same loss layer with PipeModel to unify. * add offload. * Use create_loss to unify Pipe and non-Pipe usage. * Add eval mode and offload level. * merge * support tp+pp * fix data split. * Fix position_ids in generation/eval/train. * fix data group. * add tp rank guard * Support rollout label data both with target length or source+target length. * Move metric calculation to rl_step to avoid comm. * fix pad * fix create group. * no print * Suppport inference model generation. * fix compatible for no eval model. * fix pp sync. * remove debug info * Refacor PPO training using StepTrainer. * Open PolicyTrainer loss logging postprocess. More StepTrainer docs. * more timer. * fix bugs. * Add EMA and PPOMetric * add tests * add unit test for rank guard. * Fix reshard zero3 and reshard infer. * Revert #7818 for llama and remove position_ids for gen/train/eval to align. * Move reload/clean/data_group to comm_utils and use guard to decorate them. * Offload sync and other data reuse fix. * Clead code * Update README * Update ppo_trainer * format code * Fix make_position_ids by 4d causal mask. * Fix nested_broadcast_tensor_with_empty import * Update eval with make_attention_mask --------- Co-authored-by: Zhong Hui <zhonghui.net@gmail.com> Co-authored-by: gongenlei <gongenlei@baidu.com>
1 parent 3d777c1 commit bf6d4e7

28 files changed

+4145
-998
lines changed

examples/RLHF/README.md

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# RLHF PPO
22

3-
提供了基于强化学习 PPO 算法对 LLM 进行人类偏好对齐的代码及完整使用示例。其中 PPO 代码实现细节参考了 [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver) 中的 PPO 实现,支持reward normalization、pretraining loss等常用的 PPO 稳定训练策略;示例使用 PKU-Alignment/safe-rlhf 提供的部分数据集和模型。后续将持续完善扩展,支持更好效果、更低成本、更高性能、更大规模的 RLHF 能力。
3+
提供了基于强化学习 PPO 算法对 LLM 进行人类偏好对齐的代码及完整使用示例,支持**3D 分布式并行训练以及 rollout 阶段使用预测优化进行生成加速**。其中 PPO 代码实现细节参考了 [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver) 中的 PPO 实现,支持reward normalization、pretraining loss等常用的 PPO 稳定训练策略;示例使用 PKU-Alignment/safe-rlhf 提供的部分数据集和模型。后续将持续完善扩展,支持更好效果、更低成本、更高性能、更大规模的 RLHF 能力。
44

55
## 快速开始
66

@@ -14,6 +14,9 @@
1414
├── ppo_main.py # RLHF训练脚本
1515
├── ppo_config.json # RLHF训练配置文件
1616
├── ppo_trainer.py # RLHF训练执行器py脚本
17+
├── ppo_config.json # RLHF训练配置文件
18+
├── trainer_utils.py # Trainer补丁及工具py脚本
19+
├── infer_utils.py # 生成加速工具py脚本
1720
├── data # 数据集相关目录
1821
│ └── base.py # 数据集基类及工具py文件
1922
│ └── alpaca.py # alpaca(raw)数据集py文件
@@ -24,16 +27,20 @@
2427
├── models # 模型相关目录
2528
│ └── score_model_utils.py # score model基类及工具py文件
2629
│ └── score_model.py # score model模型定义py文件
30+
│ └── ppo_model_utils.py # PPO loss等模型策略py文件
31+
│ └── pp_model_utils.py # 流水线并行补丁及工具py文件
32+
│ └── model_pp.py # 流水线并行模型py文件
33+
│ └── infer_model_utils.py # 预测加速模型补丁及工具py文件
2734
└── README.md
2835
```
2936

3037
### 环境准备
3138

3239
- Python >= 3.10
3340
- PaddlePaddle >= 2.6.0
34-
- PaddleNLP >= 2.6.0
41+
- PaddleNLP 最新版本
3542

36-
此外还需要安装以下依赖:`pip install rich`
43+
如需使用生成加速功能,需要安装 [paddlenlp_ops](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/csrc) ,请使用 `git clone https://github.com/PaddlePaddle/PaddleNLP.git` 克隆 PaddleNLP 代码库并且将 PaddleNLP/llm 目录的路径加入 PYTHONPATH(后续将进行完善)。安装 paddlenlp_ops 后训练时将直接开启生成加速(开启流水线并行时不支持生成加速),否则使用原生动态图进行生成。
3744

3845
### 数据准备
3946

@@ -184,7 +191,8 @@ python -u -m paddle.distributed.launch reward_main.py ./reward_config.json
184191
RLHF 阶段需要 actor model、reference model、critic model、reward model 四个模型;actor-model/reference-model 使用 SFT 模型进行 initialize/frozen;critic-model/reward-model 使用 reward 模型进行 initialize/frozen (另外注意若 SFT 使用 LoRA 请先将 LoRA 权重合并)。这里使用 PKU-Alignment/PKU-SafeRLHF 提供的 SFT 模型([PKU-Alignment/alpaca-7b-reproduced](https://huggingface.co/PKU-Alignment/alpaca-7b-reproduced))和 reward 模型([PKU-Alignment/beaver-7b-v1.0-reward](https://huggingface.co/PKU-Alignment/beaver-7b-v1.0-reward),注意该模型只关注 helpful 未考量 harmless)作为示例,使用 `ppo_main.py` 脚本根据 `ppo_config.json` 进行 RLHF 训练。
185192

186193
```
187-
python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
194+
# 类型提升 warning 暂时通过 loglevel 屏蔽,待后续修复
195+
GLOG_minloglevel=2 python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
188196
```
189197

190198
`ppo_config.json` 中的绝大部分参数释义同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),不再赘述,重点给出以下参数配置及释义(使用 PKU-Alignment/PKU-SafeRLHF 中的默认值):
@@ -210,7 +218,15 @@ python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
210218

211219
另外所有 [`TrainingArguments` 支持参数配置](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments)将为 actor-model 和 critic-model 的训练复用(如`sharding_stage`),除单独提供了 `critic_learning_rate/critic_weight_decay/critic_lr_scheduler_type/critic_warmup_ratio/critic_recompute` 这些参数支持为 critic-model 训练单独指定相应配置。actor-model 和 critic-model 的 checkpoints 将分别保存在 `outpt_dir` 所指定目录的 policy 和 value 文件夹下。
212220

213-
当前示例中所用数据及规模 RLHF 训练基于 sharding stage3 使用 NVIDIA A100 80G 4卡/8卡训练验证。
221+
此外为了支持更高性、更大规模的 RLHF 训练提供了以下特殊参数配置,可以按需使用:
222+
- `use_fusemt`:安装 paddlenlp_ops 后将在 rollout 生成时开启生成加速(开启流水线并行时不支持生成加速),通过此设置可以禁用生成加速。
223+
- `eval_mode`:支持为空或者设置为 "single"、"tensor_parallel";通常可以在使用流水线并行训练时设置为"tensor_parallel",以此在 rollout 生成阶段使用非流水线并行模型并进行生成加速。
224+
- `offload_level`:支持设置为"freeze_model"、"optimizer"、"train_model"或者同时使用(空格分隔),分别指示 reward+reference 两个冻结模型、actor+critic 两个训练模型的优化器状态和模型参数的 offload/reload,用于在不同阶段 model/optimizer 使用结束后及时 offload 并在下次使用时 reload 相应参数权重以节省显存。
225+
226+
另外注意,在使用流水线并行时(pipeline_parallel_degree大于1)建议将 `dataloader_drop_last` 设置为 true, 以此避免不同batch size带来的问题。
227+
228+
229+
214230

215231
### 推理
216232

0 commit comments

Comments
 (0)