
Commit d4de12c

Add PPO training. (#7305)
* Add reward model and training.
* Make reward training runable
* Add eval in reward training.
* For result alignment.
* training setting.
* PPO alignment with Beaver
* Align training with Beaver.
* Clean ppo_trainer.py and debug print.
* Move score models from paddlenlp to example and add AutoModelForScore.
* Remove eval.py, eval_score.py. Fix AutoModelForScore and update reward training usage.
* Update ppo_trainer.py after merge with develop.
* Make PPOTrainer support reference/reward Trainer optionally.
* Complete README.
* Add unittest test_load_from_custom_arch for AutoConfig
* Add test_synced_gpus.py for generation.
* Add more test cases in test_synced_gpus.py
* Support tensor parallel.
* Add require_gpu to test_synced_gpus.py.
1 parent e3fc63a commit d4de12c

23 files changed, +4485 −12 lines changed

examples/RLHF/README.md

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# RLHF PPO

This directory provides code and a complete usage example for aligning LLMs with human preferences using the PPO reinforcement-learning algorithm. The PPO implementation details follow the PPO code in [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf) (PKU Beaver) and support common PPO stabilization techniques such as reward normalization and a pretraining (ptx) loss. The example uses some of the datasets and models released by PKU-Alignment/safe-rlhf. The code will continue to be extended toward RLHF with better quality, lower cost, higher performance, and larger scale.

## Quick Start

The project is organized as follows:

```
.
├── reward_main.py             # reward model training script
├── reward_config.json         # reward model training configuration
├── reward_trainer.py          # reward model trainer
├── ppo_main.py                # RLHF training script
├── ppo_config.json            # RLHF training configuration
├── ppo_trainer.py             # RLHF trainer
├── data                       # dataset-related modules
│   ├── base.py                # dataset base classes and utilities
│   ├── alpaca.py              # alpaca (raw) dataset
│   ├── safe_rlhf.py           # safe_rlhf (raw) dataset
│   ├── preference.py          # preference dataset
│   ├── prompt_only.py         # prompt-only dataset
│   └── supervised.py          # supervised dataset
├── models                     # model-related modules
│   ├── score_model_utils.py   # score model base classes and utilities
│   └── score_model.py         # score model definitions
└── README.md
```

### Environment

- Python >= 3.10
- PaddlePaddle >= 2.6.0
- PaddleNLP >= 2.6.0

One additional dependency is required: `pip install rich`

### Data Preparation

PPO training consists of three stages: Supervised Fine-Tuning, Reward Model Fine-Tuning, and RLHF (see the Training section below). Several datasets are involved; each is described below together with how to plug in custom data.

#### Supervised Fine-Tuning Data

Follow the data preparation described in [LLM fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83).

#### Reward Model Fine-Tuning Data

Reward Model Fine-Tuning requires human preference data. The example uses the [PKU-Alignment/PKU-SafeRLHF-30K](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF-30K) dataset released by PKU-Alignment/safe-rlhf; one sample is shown below. The `prompt`, `response_0`, `response_1`, and `better_response_id` fields are used to build the preference data (the dataset annotates helpfulness and harmlessness separately via its `safe`-related fields; this example uses only the helpfulness annotation).

```json
{
    "prompt": "Is it possible to purchase a complete chess set at a pawn shop?",
    "response_0": "Yes, it is possible to purchase a complete chess set at a pawn shop.",
    "response_1": "Yes, it is possible to purchase a complete chess set at a pawn shop. Pawn shops may carry chess sets that have been either inherited or previously borrowed. There is also a possibility of finding individual chess pieces such as rooks, knights, bishops, queens, and kings. If you cannot find a complete set, you may be able to purchase the various chess pieces separately.",
    "is_response_0_safe": true,
    "is_response_1_safe": true,
    "better_response_id": 1,
    "safer_response_id": 1
}
```
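
A raw dataset with this schema maps directly onto the preference fields described in the Custom Data section below. The following sketch shows one way such a mapping could look; the class and its `NAME` are hypothetical and only illustrative, since a built-in `PKU-SafeRLHF-30K` dataset is already provided (see the list of registered datasets below).

```python
from datasets import load_dataset

from data import RawDataset, RawSample


class SafeRLHF30KHelpful(RawDataset):  # hypothetical class, for illustration only
    NAME = "my-saferlhf-30k-helpful"

    def __init__(self, path=None) -> None:
        self.data = load_dataset(path or "PKU-Alignment/PKU-SafeRLHF-30K", split="train")

    def __getitem__(self, index: int) -> RawSample:
        data = self.data[index]
        # Keep only the helpfulness annotation: response_{better_response_id} is preferred.
        better_id = int(data["better_response_id"])
        return RawSample(
            input=data["prompt"],
            answer=data["response_0"],
            other_answer=data["response_1"],
            better=(better_id == 0),
        )

    def __len__(self) -> int:
        return len(self.data)
```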

#### RLHF Data

The RLHF stage uses prompt-only data; optionally, additional supervised data can be provided to build an LM loss that constrains RLHF training. The example uses the [PKU-Alignment/PKU-SafeRLHF](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF) dataset (also a human preference dataset; only its `prompt` field is used, with duplicate prompts removed). In addition, data from [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) is used to build the extra loss term.

These example datasets are downloaded and cached automatically during training.

#### Custom Data

Data definitions are built around two preset classes, `RawSample` and `RawDataset`: `RawSample` specifies the sample-level protocol and `RawDataset` specifies the dataset-level protocol. Data that follows these protocols can be consumed through the three dataset interfaces required by RLHF training: `SupervisedDataset`, `PreferenceDataset`, and `PromptOnlyDataset`.

A custom dataset needs to:

- Inherit from `RawDataset` and define the class attribute `NAME`, which registers the dataset.
- Implement `__init__` (load the data), `__getitem__` (fetch the sample at `index` and return it as a `RawSample`), and `__len__` (dataset size).

For example:

```python
from datasets import load_dataset
from data import RawDataset, RawSample

class MyRawDataset(RawDataset):
    NAME = 'my-dataset-name'

    def __init__(self, path=None) -> None:
        # Load a dataset from Hugging Face or any other data source
        # self.data = load_dataset(path or 'my-organization/my-dataset')['train']
        self.data = [{
            'col1': 'question',
            'col2': 'answer1',
            'col3': 'answer2',
            'col4': 1,  # score of answer1
            'col5': 2   # score of answer2
        }] * 10  # dummy data for example

    def __getitem__(self, index: int) -> RawSample:
        data = self.data[index]
        # Construct a `RawSample` dictionary from your custom dataset item
        return RawSample(
            input=data['col1'],
            answer=data['col2'],
            other_answer=data['col3'],
            better=float(data['col4']) > float(data['col5']),
        )

    def __len__(self) -> int:
        return len(self.data)  # dataset size
```

`RawSample` is a superset of the sample types used throughout RLHF training, as shown below, so it can bridge the sample types needed by each training stage. When defining custom data, use the `(input, answer)` fields of `RawSample` for SFT data, the `(input, answer, other_answer, better)` fields for human preference data, and the `(input)` field for prompt-only data.

```python
class RawSample(TypedDict, total=False):
    """Raw sample type.

    For SupervisedDataset, should provide (input, answer) or (dialogue).
    For PreferenceDataset, should provide (input, answer, other_answer, better).
    For SafetyPreferenceDataset, should provide (input, answer, other_answer, safer, is_safe, is_other_safe).
    For PromptOnlyDataset, should provide (input).

    When input is a list, it would be processed as a dialogue.
    """

    # Texts
    input: NotRequired[str | list[str]]  # either `input` or `dialogue` should be provided
    """User input text."""
    answer: NotRequired[str]
    """Assistant answer text."""
    other_answer: NotRequired[str]
    """Other assistant answer text via resampling."""
    dialogue: NotRequired[list[str]]  # either `input` or `dialogue` should be provided
    """Dialogue history."""

    # Flags
    better: NotRequired[bool]
    """Whether ``answer`` is better than ``other_answer``."""
    safer: NotRequired[bool]
    """Whether ``answer`` is safer than ``other_answer``."""
    is_safe: NotRequired[bool]
    """Whether ``answer`` is safe."""
    is_other_safe: NotRequired[bool]
    """Whether ``other_answer`` is safe."""
```
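
For reference, a minimal sketch of how a sample for each stage would be built from these fields (the text values are placeholders):

```python
from data import RawSample

# SFT sample: (input, answer)
sft_sample = RawSample(
    input="Is it possible to purchase a complete chess set at a pawn shop?",
    answer="Yes, pawn shops sometimes carry complete chess sets.",
)

# Human preference sample: (input, answer, other_answer, better)
preference_sample = RawSample(
    input="Is it possible to purchase a complete chess set at a pawn shop?",
    answer="Yes, it is possible to purchase a complete chess set at a pawn shop.",
    other_answer="No, pawn shops never sell chess sets.",
    better=True,  # `answer` is preferred over `other_answer`
)

# Prompt-only sample: (input)
prompt_only_sample = RawSample(
    input="Is it possible to purchase a complete chess set at a pawn shop?",
)
```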

A dataset defined this way can then be used through the preset interfaces by its `NAME`. The following datasets are currently built in: `"PKU-SafeRLHF/train", "PKU-SafeRLHF/test", "PKU-SafeRLHF-30K/train", "PKU-SafeRLHF-30K/test", "PKU-SafeRLHF-10K/train", "alpaca"`. Multiple datasets can also be combined with per-dataset ratios, so a different data mix can be prepared for each training stage. For example:

```python
from paddlenlp.transformers import AutoTokenizer
from data import PreferenceDataset

tokenizer = AutoTokenizer.from_pretrained('facebook/llama-7b')
dataset = PreferenceDataset({
    'alpaca': 0.75,
    'my-dataset-name': 0.5
}, tokenizer)
```

### Training

The full PPO training pipeline consists of the following three stages, illustrated below (figure from [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat)):

<p align="center">
  <img src="https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/assets/image/ppo_trainer.png?raw=true" align="middle" width="600" />
</p>

1. Supervised Fine-Tuning (SFT)

Follow [LLM fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83) to train the model and use the resulting checkpoint.

2. Reward Model Fine-Tuning

Use the `reward_main.py` script with `reward_config.json` to train the reward model:

```
python -u -m paddle.distributed.launch reward_main.py ./reward_config.json
```

Most parameters in `reward_config.json` have the same meaning as in [LLM fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83) and are not repeated here. One difference is that `train_datasets`/`eval_datasets` specify the training and evaluation sets by the `NAME` attribute used when registering the dataset. In addition, reward model training has the following dedicated options (defaults follow PKU-Alignment/PKU-SafeRLHF); an illustrative config sketch is shown after the list:

- `normalize_score_during_training`: whether to normalize rewards during training; defaults to `False`.
- `normalizer_type`: how the normalizer computes mean and variance; one of `"RunningMeanStd", "ExponentialMovingAverage"`.
- `normalizer_momentum`: momentum used by the `ExponentialMovingAverage` normalizer; defaults to `0.9`.
- `loss_type`: whether to train the reward model with token-level or sequence-level loss; one of `"token-wise", "sequence-wise"`; defaults to `"sequence-wise"`.
- `regularization`: coefficient of the reward regularization term in the training objective; defaults to `0.001`.
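
As a rough illustration, the sketch below builds a `reward_config.json`-style configuration. It is a sketch only: `model_name_or_path`, `output_dir`, and all numeric training values are assumptions (standard PaddleNLP Trainer keys and placeholder values), not the shipped configuration; the reward-specific keys use the documented defaults, and the dataset names come from the built-in registry.

```python
import json

reward_config = {
    # Assumption: initialize the reward model from the example's SFT model.
    "model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced",
    # Datasets are referenced by their registered NAME.
    "train_datasets": "PKU-SafeRLHF-30K/train",
    "eval_datasets": "PKU-SafeRLHF-30K/test",
    "output_dir": "./checkpoints/reward_model",  # placeholder
    # Reward-specific options (documented defaults).
    "normalize_score_during_training": False,
    "normalizer_type": "ExponentialMovingAverage",
    "normalizer_momentum": 0.9,
    "loss_type": "sequence-wise",
    "regularization": 0.001,
}

with open("reward_config.json", "w") as f:
    json.dump(reward_config, f, indent=2)
```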

3. RLHF

The RLHF stage requires four models: the actor model, the reference model, the critic model, and the reward model. The actor/reference models are initialized/frozen from the SFT model, and the critic/reward models are initialized/frozen from the reward model (note that if the SFT stage used LoRA, merge the LoRA weights first). The example uses the SFT model ([PKU-Alignment/alpaca-7b-reproduced](https://huggingface.co/PKU-Alignment/alpaca-7b-reproduced)) and the reward model ([PKU-Alignment/beaver-7b-v1.0-reward](https://huggingface.co/PKU-Alignment/beaver-7b-v1.0-reward); note this reward model scores helpfulness only, not harmlessness) released by PKU-Alignment/PKU-SafeRLHF, and runs RLHF training with the `ppo_main.py` script and `ppo_config.json`:

```
python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
```

Most parameters in `ppo_config.json` have the same meaning as in [LLM fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83) and are not repeated here. The key options are listed below (defaults follow PKU-Alignment/PKU-SafeRLHF):

- `train_datasets`: training set, specified by the `NAME` attribute used when registering the dataset.
- `eval_datasets`: evaluation set, specified by the `NAME` attribute used when registering the dataset.
- `ptx_datasets`: dataset used for the ptx-loss, specified by the `NAME` attribute used when registering the dataset; the ptx-loss is disabled when not provided.
- `actor_model_name_or_path`: model name or directory used to initialize/freeze the actor model and reference model.
- `reward_model_name_or_path`: model name or directory of the reward model.
- `reward_critic_model_name_or_path`: model name or directory of the critic model; when not provided, `reward_model_name_or_path` is used to initialize the critic model.
- `per_device_prompt_batch_size`: per-device batch size for reading the prompt-only dataset during rollout generation.
- `per_device_train_batch_size`: per-device batch size for generation from prompts and for training.
- `num_return_sequences`: number of responses generated per prompt, i.e. `GenerationConfig.num_return_sequences`; all responses are used for training.
- `temperature`: sampling `temperature`, i.e. `GenerationConfig.temperature`.
- `top_p`: top-p-filtering threshold for sampling, i.e. `GenerationConfig.top_p`.
- `repetition_penalty`: repetition penalty coefficient for sampling, i.e. `GenerationConfig.repetition_penalty`.
- `update_iters`: number of times each batch of generated data is reused.
- `kl_coeff`: coefficient of the KL penalty applied to the reward (see the sketch after this list).
- `clip_range_score`: threshold for clipping the reward.
- `clip_range_value`: the critic (value function) output is clipped when its new value for the current sequence deviates from the old value in the experience buffer by more than this range.
- `clip_range_ratio`: the ratio of the new probability of the current sequence to the old probability in the experience buffer is clipped to `(1-clip_range_ratio, 1+clip_range_ratio)` (PPO-Clip); see the sketch after this list.
- `ptx_coeff`: coefficient of the pretraining loss term (ptx-loss).
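
The `kl_coeff` and `clip_range_*` options above follow the standard PPO-Clip formulation with a KL-penalized reward. The NumPy sketch below is illustrative only; it is not the `ppo_trainer.py` implementation and simplifies details such as where along the sequence the reward score is added.

```python
import numpy as np


def kl_penalized_reward(score, log_prob, ref_log_prob, kl_coeff, clip_range_score):
    """Clipped reward-model score minus a KL penalty w.r.t. the reference model.

    Simplified: implementations typically add the clipped score only at the final
    token of the response, while the KL penalty applies to every token.
    """
    kl = log_prob - ref_log_prob
    clipped_score = np.clip(score, -clip_range_score, clip_range_score)
    return clipped_score - kl_coeff * kl


def ppo_policy_loss(log_prob, old_log_prob, advantage, clip_range_ratio):
    """PPO-Clip surrogate: the ratio to the rollout (old) policy is clipped to (1-r, 1+r)."""
    ratio = np.exp(log_prob - old_log_prob)
    clipped_ratio = np.clip(ratio, 1.0 - clip_range_ratio, 1.0 + clip_range_ratio)
    return -np.mean(np.minimum(ratio * advantage, clipped_ratio * advantage))


def ppo_value_loss(value, old_value, returns, clip_range_value):
    """Clipped value loss: new values may deviate from buffered values by at most clip_range_value."""
    clipped_value = np.clip(value, old_value - clip_range_value, old_value + clip_range_value)
    return 0.5 * np.mean(np.maximum((value - returns) ** 2, (clipped_value - returns) ** 2))
```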

In addition, all [parameters supported by `TrainingArguments`](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments) are shared by actor-model and critic-model training (e.g. `sharding_stage`), except that `critic_learning_rate/critic_weight_decay/critic_lr_scheduler_type/critic_warmup_ratio/critic_recompute` are provided so these settings can be specified separately for the critic model. The actor-model and critic-model checkpoints are saved under the `policy` and `value` subdirectories of `output_dir`, respectively. An illustrative config sketch follows.
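
The sketch below shows how these options might fit together in a `ppo_config.json`-style configuration. The model and dataset names are the ones referenced in this example; `output_dir` and all numeric values are placeholders, not the shipped defaults.

```python
import json

ppo_config = {
    # Models: actor/reference initialized (and reference frozen) from the SFT model,
    # reward/critic from the reward model.
    "actor_model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced",
    "reward_model_name_or_path": "PKU-Alignment/beaver-7b-v1.0-reward",
    # Data: prompt-only train/eval sets plus an optional ptx dataset, by registered NAME.
    "train_datasets": "PKU-SafeRLHF/train",
    "eval_datasets": "PKU-SafeRLHF/test",
    "ptx_datasets": "alpaca",
    # Rollout and optimization batch sizes (per device); placeholder values.
    "per_device_prompt_batch_size": 16,
    "per_device_train_batch_size": 16,
    # Generation settings used for rollout; placeholder values.
    "num_return_sequences": 1,
    "temperature": 1.0,
    "top_p": 0.8,
    "repetition_penalty": 1.0,
    # PPO settings; placeholder values.
    "update_iters": 1,
    "kl_coeff": 0.02,
    "clip_range_score": 50.0,
    "clip_range_value": 5.0,
    "clip_range_ratio": 0.2,
    "ptx_coeff": 16.0,
    # Critic-specific override; other TrainingArguments are shared with the actor.
    "critic_learning_rate": 5e-6,
    "output_dir": "./checkpoints/ppo",  # policy/ and value/ checkpoints are saved here
}

with open("ppo_config.json", "w") as f:
    json.dump(ppo_config, f, indent=2)
```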

With the data and model scale used in this example, RLHF training has been validated with sharding stage3 on 4 or 8 NVIDIA A100 80G GPUs.

### Inference

After training, the checkpoints in the `policy` subdirectory of `output_dir` can be used directly for inference as described in [LLM inference](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#4-%E6%8E%A8%E7%90%86); please refer to that section.

## Acknowledgements

We drew on the excellent design and implementation of [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf) (PKU Beaver) and thank its authors.

## References

- Zheng R, Dou S, Gao S, et al. Secrets of RLHF in Large Language Models Part I: PPO. arXiv preprint arXiv:2307.04964, 2023.
- Dai J, Pan X, Sun R, et al. Safe RLHF: Safe Reinforcement Learning from Human Feedback. arXiv preprint arXiv:2310.12773, 2023.

examples/RLHF/data/__init__.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.io import Dataset

from .alpaca import *
from .base import *
from .preference import *
from .prompt_only import *
from .safe_rlhf import *
from .supervised import *


class DummyDataset(Dataset):
    """Dataset of a given length whose items are empty dicts."""

    def __init__(self, length: int) -> None:
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int):
        return {}

examples/RLHF/data/alpaca.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stanford Alpaca dataset for supervised instruction fine-tuning."""

from __future__ import annotations

from datasets import load_dataset

from .base import RawDataset, RawSample

__all__ = ["AlpacaDataset"]


class AlpacaDataset(RawDataset):
    NAME: str = "alpaca"
    ALIASES: tuple[str, ...] = ("stanford-alpaca",)

    def __init__(self, path: str | None = None) -> None:
        self.data = load_dataset(path or "tatsu-lab/alpaca", split="train")

    def __getitem__(self, index: int) -> RawSample:
        data = self.data[index]
        # Join the instruction and the optional input field into a single prompt.
        input = (  # pylint: disable=redefined-builtin
            " ".join((data["instruction"], data["input"])) if data["input"] else data["instruction"]
        )
        answer = data["output"]
        return RawSample(input=input, answer=answer)

    def __len__(self) -> int:
        return len(self.data)
