|
| 1 | +# RLHF PPO |
| 2 | + |
| 3 | +提供了基于强化学习 PPO 算法对 LLM 进行人类偏好对齐的代码及完整使用示例。其中 PPO 代码实现细节参考了 [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver) 中的 PPO 实现,支持reward normalization、pretraining loss等常用的 PPO 稳定训练策略;示例使用 PKU-Alignment/safe-rlhf 提供的部分数据集和模型。后续将持续完善扩展,支持更好效果、更低成本、更高性能、更大规模的 RLHF 能力。 |
| 4 | + |
| 5 | +## 快速开始 |
| 6 | + |
| 7 | +项目整体组织结构如下: |
| 8 | + |
| 9 | +``` |
| 10 | +. |
| 11 | +├── reward_main.py # reward model训练脚本 |
| 12 | +├── reward_config.json # reward model训练配置文件 |
| 13 | +├── reward_trainer.py # reward训练执行器py脚本 |
| 14 | +├── ppo_main.py # RLHF训练脚本 |
| 15 | +├── ppo_config.json # RLHF训练配置文件 |
| 16 | +├── ppo_trainer.py # RLHF训练执行器py脚本 |
| 17 | +├── data # 数据集相关目录 |
| 18 | +│ └── base.py # 数据集基类及工具py文件 |
| 19 | +│ └── alpaca.py # alpaca(raw)数据集py文件 |
| 20 | +│ └── safe_rlhf.py # safe_rlhf(raw)数据集py文件 |
| 21 | +│ └── preference.py # 偏好数据集py文件 |
| 22 | +│ └── prompt_only.py # prompt only数据集py文件 |
| 23 | +│ └── supervised.py # supervised数据集py文件 |
| 24 | +├── models # 模型相关目录 |
| 25 | +│ └── score_model_utils.py # score model基类及工具py文件 |
| 26 | +│ └── score_model.py # score model模型定义py文件 |
| 27 | +└── README.md |
| 28 | +``` |
| 29 | + |
| 30 | +### 环境准备 |
| 31 | + |
| 32 | +- Python >= 3.10 |
| 33 | +- PaddlePaddle >= 2.6.0 |
| 34 | +- PaddleNLP >= 2.6.0 |
| 35 | + |
| 36 | +此外还需要安装以下依赖:`pip install rich` |
| 37 | + |
| 38 | +### 数据准备 |
| 39 | + |
| 40 | +PPO 训练包括 Supervised Fine-Tuning、Reward Model Fine-Tuning、RLHF 三个阶段(可见下文训练部分),会涉及到多个数据集,下面分别介绍并给出自定义数据的方法。 |
| 41 | + |
| 42 | +#### Supervised Fine-Tuning 数据 |
| 43 | +同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),可以直接参考对应内容进行数据准备。 |
| 44 | + |
| 45 | +#### Reward Model Fine-Tuning 数据 |
| 46 | +Reward Model Fine-Tuning 阶段需要使用人类偏好数据。示例使用 PKU-Alignment/safe-rlhf 提供的 [PKU-Alignment/PKU-SafeRLHF-30K](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF-30K) 数据集,下面是其中一条样本,这里使用其中的`prompt、response_0、response_1、better_response_id`字段来组织偏好数据(safe字段,该数据集将helpful 和 harmless 分开标注,示例这里使用其 helpful 标注)。 |
| 47 | + |
| 48 | +```json |
| 49 | +{ |
| 50 | + "prompt": "Is it possible to purchase a complete chess set at a pawn shop?", |
| 51 | + "response_0": "Yes, it is possible to purchase a complete chess set at a pawn shop.", |
| 52 | + "response_1": "Yes, it is possible to purchase a complete chess set at a pawn shop. Pawn shops may carry chess sets that have been either inherited or previously borrowed. There is also a possibility of finding individual chess pieces such as rooks, knights, bishops, queens, and kings. If you cannot find a complete set, you may be able to purchase the various chess pieces separately.", |
| 53 | + "is_response_0_safe": true, |
| 54 | + "is_response_1_safe": true, |
| 55 | + "better_response_id": 1, |
| 56 | + "safer_response_id": 1 |
| 57 | +} |
| 58 | +``` |
| 59 | + |
| 60 | +#### RLHF 数据 |
| 61 | +RLHF 阶段使用 prompt only 数据,另外可以可选的提供额外的监督数据用于构建 LM 损失约束 RLHF 训练。示例使用 [PKU-Alignment/PKU-SafeRLHF](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF) 数据集(同样是人类偏好数据集,这里只使用其 prompt 字段并对 prompt 去重)。此外还使用了 [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)的数据来构建额外的损失项。 |
| 62 | + |
| 63 | +以上示例数据集在训练时将自动下载缓存使用。 |
| 64 | + |
| 65 | +#### 自定义数据 |
| 66 | +数据定义围绕 `RawSample` 和 `RawDataset` 两个预置的类进行;`RawSample` 提供了数据样本级别接入协议规范,`RawDataset` 提供了数据集级别接入协议规范;按照下面介绍的规范接入,即可通过预置的 `SupervisedDataset`、`PreferenceDataset`、`PromptOnlyDataset` 三类 RLHF 训练所需数据接口来使用自定义数据。 |
| 67 | + |
| 68 | +自定义数据集需要: |
| 69 | +- 继承 `RawDataset` ,并定义类属性 `NAME` 用于注册数据集。 |
| 70 | +- 实现 `__init__` 方法(加载数据),`__getitem__` 方法(根据 index 获取样本并转换为 `RawSample` 对象返回)、`__len__` 方法(数据集大小)。 |
| 71 | + |
| 72 | +示例如下: |
| 73 | + |
| 74 | +```python |
| 75 | +from datasets import load_dataset |
| 76 | +from data import RawDataset, RawSample |
| 77 | + |
| 78 | +class MyRawDataset(RawDataset): |
| 79 | + NAME = 'my-dataset-name' |
| 80 | + |
| 81 | + def __init__(self, path=None) -> None: |
| 82 | + # Load a dataset from Hugging Face or any other data source |
| 83 | + # self.data = load_dataset(path or 'my-organization/my-dataset')['train'] |
| 84 | + self.data = [{ |
| 85 | + 'col1': 'question', |
| 86 | + 'col2': 'answer1', |
| 87 | + 'col3': 'answer2', |
| 88 | + 'col4': 1, # score of answer1 |
| 89 | + 'col5': 2 # score of answer2 |
| 90 | + }] * 10 # dummy data for example |
| 91 | + |
| 92 | + def __getitem__(self, index: int) -> RawSample: |
| 93 | + data = self.data[index] |
| 94 | + # Construct a `RawSample` dictionary from your custom dataset item |
| 95 | + return RawSample( |
| 96 | + input=data['col1'], |
| 97 | + answer=data['col2'], |
| 98 | + other_answer=data['col3'], |
| 99 | + better=float(data['col4']) > float(data['col5']), |
| 100 | + ) |
| 101 | + |
| 102 | + def __len__(self) -> int: |
| 103 | + return len(self.data) # dataset size |
| 104 | +``` |
| 105 | + |
| 106 | +其中 `RawSample` 是整个 RLHF 训练过程用到的几种数据类型的超集,如下所示,其可以桥接各训练阶段所需样本类型。在自定义数据时,对于 SFT 数据使用 `RawSample` 的 `(input, answer)` 字段;对于人类偏好数据使用 `RawSample` 的 `(input, answer, other_answer, better)` 字段;对于 prompt only 数据,使用 `RawSample` 的 `(input)`字段。 |
| 107 | + |
| 108 | +```python |
| 109 | +class RawSample(TypedDict, total=False): |
| 110 | + """Raw sample type. |
| 111 | +
|
| 112 | + For SupervisedDataset, should provide (input, answer) or (dialogue). |
| 113 | + For PreferenceDataset, should provide (input, answer, other_answer, better). |
| 114 | + For SafetyPreferenceDataset, should provide (input, answer, other_answer, safer, is_safe, is_other_safe). |
| 115 | + For PromptOnlyDataset, should provide (input). |
| 116 | +
|
| 117 | + When input is a list, it would be processed as a dialogue. |
| 118 | + """ |
| 119 | + |
| 120 | + # Texts |
| 121 | + input: NotRequired[str | list[str]] # either `input` or `dialogue` should be provided |
| 122 | + """User input text.""" |
| 123 | + answer: NotRequired[str] |
| 124 | + """Assistant answer text.""" |
| 125 | + other_answer: NotRequired[str] |
| 126 | + """Other assistant answer text via resampling.""" |
| 127 | + dialogue: NotRequired[list[str]] # either `input` or `dialogue` should be provided |
| 128 | + """Dialogue history.""" |
| 129 | + |
| 130 | + # Flags |
| 131 | + better: NotRequired[bool] |
| 132 | + """Whether ``answer`` is better than ``other_answer``.""" |
| 133 | + safer: NotRequired[bool] |
| 134 | + """Whether ``answer`` is safer than ``other_answer``.""" |
| 135 | + is_safe: NotRequired[bool] |
| 136 | + """Whether ``answer`` is safe.""" |
| 137 | + is_other_safe: NotRequired[bool] |
| 138 | + """Whether ``other_answer`` is safe.""" |
| 139 | +``` |
| 140 | + |
| 141 | +如此定义的数据集将可以通过预置接口根据 `NAME` 来使用,当前内置支持`"PKU-SafeRLHF/train", "PKU-SafeRLHF/test", "PKU-SafeRLHF-30K/train", "PKU-SafeRLHF-30K/test", "PKU-SafeRLHF-10K/train", "alpaca"` 几个数据集。另外还支持使用多个数据集并指定数据比例,我们可以按照需要为每个阶段训练准备多份数据集。示例如下: |
| 142 | + |
| 143 | +```python |
| 144 | +from paddlenlp.transformers import AutoTokenizer |
| 145 | +from data import PreferenceDataset |
| 146 | + |
| 147 | +tokenizer = AutoTokenizer.from_pretrained('facebook/llama-7b') |
| 148 | +dataset = PreferenceDataset({ |
| 149 | + 'alpaca': 0.75, |
| 150 | + 'my-dataset-name': 0.5 |
| 151 | +}, tokenizer) |
| 152 | +``` |
| 153 | + |
| 154 | +### 训练 |
| 155 | + |
| 156 | +PPO 完整的训练过程包括以下 3 个阶段,如下图所示(来自[DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat)): |
| 157 | + |
| 158 | +<p align="center"> |
| 159 | + <img src="https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/assets/image/ppo_trainer.png?raw=true" align="middle" width = "600" /> |
| 160 | +</p> |
| 161 | + |
| 162 | +1. Supervised Fine-Tuning (SFT) |
| 163 | + |
| 164 | +同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),可以直接参考对应内容进行训练并使用其产出模型。 |
| 165 | + |
| 166 | +2. Reward Model Fine-Tuning |
| 167 | + |
| 168 | +使用 `reward_main.py` 脚本根据 `reward_config.json` 训练奖励模型 |
| 169 | + |
| 170 | +``` |
| 171 | +python -u -m paddle.distributed.launch reward_main.py ./reward_config.json |
| 172 | +``` |
| 173 | + |
| 174 | +`reward_config.json` 中的绝大部分参数释义同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),不再赘述;稍有区别的是 `train_datasets`/`eval_datasets` 分别使用数据集定义注册时的`NAME`属性给出训练和验证集。另外对于奖励模型训练有以下特殊参数配置及释义(使用 PKU-Alignment/PKU-SafeRLHF 中的默认值): |
| 175 | + |
| 176 | +- `normalize_score_during_training`:是否在训练过程中对奖励进行 normalize,默认为 `False`。 |
| 177 | +- `normalizer_type`:使用 normalizer 时计算 mean、var 的方式,可选`"RunningMeanStd", "ExponentialMovingAverage"`。 |
| 178 | +- `normalizer_momentum`:使用 `ExponentialMovingAverage` normalizer 时指定的 momentum ,默认为 `0.9`。 |
| 179 | +- `loss_type`:使用 token 级或是 sequence 级 loss 进行奖励模型训练,可选`"token-wise", "sequence-wise"`,默认为 `"sequence-wise"`。 |
| 180 | +- `regularization`:奖励模型训练目标中对奖励的正则化系数,默认为 `0.001`。 |
| 181 | + |
| 182 | +3. RLHF: |
| 183 | + |
| 184 | +RLHF 阶段需要 actor model、reference model、critic model、reward model 四个模型;actor-model/reference-model 使用 SFT 模型进行 initialize/frozen;critic-model/reward-model 使用 reward 模型进行 initialize/frozen (另外注意若 SFT 使用 LoRA 请先将 LoRA 权重合并)。这里使用 PKU-Alignment/PKU-SafeRLHF 提供的 SFT 模型([PKU-Alignment/alpaca-7b-reproduced](https://huggingface.co/PKU-Alignment/alpaca-7b-reproduced))和 reward 模型([PKU-Alignment/beaver-7b-v1.0-reward](https://huggingface.co/PKU-Alignment/beaver-7b-v1.0-reward),注意该模型只关注 helpful 未考量 harmless)作为示例,使用 `ppo_main.py` 脚本根据 `ppo_config.json` 进行 RLHF 训练。 |
| 185 | + |
| 186 | +``` |
| 187 | +python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json |
| 188 | +``` |
| 189 | + |
| 190 | +`ppo_config.json` 中的绝大部分参数释义同[LLM 精调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83),不再赘述,重点给出以下参数配置及释义(使用 PKU-Alignment/PKU-SafeRLHF 中的默认值): |
| 191 | + |
| 192 | +- `train_datasets`:使用数据集定义注册时的`NAME`属性给出训练集。 |
| 193 | +- `eval_datasets`:使用数据集定义注册时的`NAME`属性给出验证集。 |
| 194 | +- `ptx_datasets`:使用数据集定义注册时的`NAME`属性给出 ptx-loss 使用的数据集,未提供时将不使用 ptx-loss。 |
| 195 | +- `actor_model_name_or_path`:actor-model/reference-model 用来 initialize/frozen 的模型名称或目录。 |
| 196 | +- `reward_model_name_or_path`:reward-model 的模型名称或目录。 |
| 197 | +- `reward_critic_model_name_or_path`:critic-model 的模型名称或目录,未提供时将使用`reward_model_name_or_path`进行 critic-model 的初始化。 |
| 198 | +- `per_device_prompt_batch_size`:训练时 prompt only 数据集读取用于 rollout 生成的批次大小(每张卡)。 |
| 199 | +- `per_device_train_batch_size`:根据 prompt 进行生成及训练使用的批次大小(每张卡)。 |
| 200 | +- `num_return_sequences`:生成时每个 prompt 生成的回复个数,即 `GenerationConfig.num_return_sequences`,所有回复都将用来训练。 |
| 201 | +- `temperature`:生成采样时使用的 `temperature` ,即 `GenerationConfig.temperature`。 |
| 202 | +- `top_p`:生成采样时 top-p-filtering 阈值,即 `GenerationConfig.top_p`。 |
| 203 | +- `repetition_penalty`:生成采样时长度惩罚系数,即 `GenerationConfig.repetition_penalty`。 |
| 204 | +- `update_iters`:一次生成的数据被使用的次数。 |
| 205 | +- `kl_coeff`:对 reward 进行 KL-Penalty 的系数。 |
| 206 | +- `clip_range_score`:对 reward 进行裁剪的阈值。 |
| 207 | +- `clip_range_value`:critic model(value function)对当前sequence的新值与Experience Buffer中旧值的差距超过该范围将进行裁剪。 |
| 208 | +- `clip_range_ratio`:将当前sequence的新概率与Experience Buffer中旧概率比值裁剪到`(1-clip_range_ratio, 1+clip_range_ratio)`范围(PPO-Clip)。 |
| 209 | +- `ptx_coeff`: 预训练损失项 ptx-loss 的系数。 |
| 210 | + |
| 211 | +另外所有 [`TrainingArguments` 支持参数配置](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments)将为 actor-model 和 critic-model 的训练复用(如`sharding_stage`),除单独提供了 `critic_learning_rate/critic_weight_decay/critic_lr_scheduler_type/critic_warmup_ratio/critic_recompute` 这些参数支持为 critic-model 训练单独指定相应配置。actor-model 和 critic-model 的 checkpoints 将分别保存在 `outpt_dir` 所指定目录的 policy 和 value 文件夹下。 |
| 212 | + |
| 213 | +当前示例中所用数据及规模 RLHF 训练基于 sharding stage3 使用 NVIDIA A100 80G 4卡/8卡训练验证。 |
| 214 | + |
| 215 | +### 推理 |
| 216 | + |
| 217 | +训练完成后可以直接使用 `outpt_dir` 所指定目录中 policy 文件夹下的 checkpoints 按照[LLM 推理](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#4-%E6%8E%A8%E7%90%86)部分的介绍来进行推理,请参考相应部分内容。 |
| 218 | + |
| 219 | +## Acknowledge |
| 220 | + |
| 221 | +我们借鉴了[PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf)(PKU Beaver)的优秀设计实现,在此对其作者表示感谢。 |
| 222 | + |
| 223 | +## 参考文献 |
| 224 | +- Zheng R, Dou S, Gao S, et al. Secrets of rlhf in large language models part i: Ppo[J]. arXiv preprint arXiv:2307.04964, 2023. |
| 225 | +- Dai J, Pan X, Sun R, et al. Safe rlhf: Safe reinforcement learning from human feedback[J]. arXiv preprint arXiv:2310.12773, 2023. |
0 commit comments