
Commit 909be01

add llama & qwen dpo (#8474)
* add llama&qwen dpo
* add
* add dpo
* fix bug
* add
1 parent 547d29c commit 909be01

19 files changed (+1513 / -80 lines)

llm/README.md

Lines changed: 46 additions & 6 deletions
@@ -155,7 +155,47 @@ python finetune_generation.py ./llama/pt_argument.json

For more documentation on distributed fine-tuning of large models, training details, and results, see the [LLM fine-tuning tutorial](./docs/finetune.md).

-### 3. Quantization
+### 3. Alignment

We support preference alignment strategies such as DPO (Direct Preference Optimization).

**Data preparation**

The data format we support is a json file with one dictionary per line; each dictionary contains the following fields (a minimal writing sketch follows the sample data below):

- `src` : `str, List(str)`, the user's dialogue content.
- `tgt` : `str, List(str)`, the system's reply content.
- `response` : `str, List(str)`, the chosen and rejected responses.
- `sort` : `List(int)`, the sort values distinguish the chosen and rejected entries in `response` (the entry with the smaller sort value is rejected, the one with the larger sort value is chosen).

Sample data:
```
{
    "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
    "tgt": [],
    "response": [
        "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
        "As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
    ],
    "sort": [1, 0]
}

...
```
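
As a minimal sketch of this format (with made-up dialogue content, not data from this commit), the snippet below writes one preference record to `./data/train.jsonl`, the default training path declared in `DPODataArgument`:

```python
import json

# One preference record: `response` holds two candidate replies and `sort`
# marks which one is preferred (larger value = chosen, smaller = rejected).
example = {
    "src": ["What is the capital of France?"],
    "tgt": [],
    "response": [
        "Paris is the capital of France.",  # chosen (sort value 1)
        "I am not sure, maybe Lyon?",       # rejected (sort value 0)
    ],
    "sort": [1, 0],
}

# Append one json object per line (jsonl).
with open("./data/train.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(example, ensure_ascii=False) + "\n")
```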

For convenient testing, we also provide a preference dataset (ultrafeedback_binarized) that can be used directly:

```bash
wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
tar -zxvf ultrafeedback_binarized.tar.gz
```

**DPO training**

```bash
# Reference launch command for llama DPO on 8 GPUs
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" dpo_train.py ./llama/dpo_argument.json
```
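
The fields read from `./llama/dpo_argument.json` map onto the argument dataclasses added in `llm/dpo_argument.py` below, plus the usual PaddleNLP `TrainingArguments` options. As a rough sketch (not a config shipped with this commit; paths and hyperparameters are placeholders), such a file could be generated like this:

```python
import json

# Illustrative values only; the keys mirror DPOModelArgument / DPODataArgument /
# DPOTrainingArguments (see llm/dpo_argument.py).
dpo_config = {
    "model_name_or_path": "path/to/your/sft-checkpoint",  # placeholder starting model
    "train_dataset_path": "./data/train.jsonl",
    "dev_dataset_path": "./data/dev.jsonl",
    "max_seq_len": 4096,
    "max_prompt_len": 2048,
    "dpo_beta": 0.1,
    "dpo_loss_type": "sigmoid",
    "output_dir": "./checkpoints/llama_dpo",  # standard TrainingArguments field
    "per_device_train_batch_size": 1,
    "learning_rate": 1e-6,
}

with open("./llama/dpo_argument.json", "w") as f:
    json.dump(dpo_config, f, indent=4)
```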

### 4. Quantization
Large model quantization converts 16-bit and 32-bit floating-point model parameters or activations into 4-bit or 8-bit integers, which effectively reduces model storage and compute requirements while speeding up inference. The quantization algorithms in the toolchain include:
- **PTQ**. An adaptive Shift-SmoothQuant quantization algorithm developed by the PaddleSlim team. Building on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145),
it adds a PieceWiseSearch parameter-search algorithm that adjusts the model weight and activation distributions to reduce the loss of the subsequent A8W8 PTQ quantization.

@@ -184,7 +224,7 @@ python finetune_generation.py ./llama/ptq_argument.json
For more technical details and model quantization usage, see the [quantization documentation](./docs/quantization.md).


-### 4. Inference
+### 5. Inference
In addition to standard model inference, PaddleNLP provides high-performance inference with built-in dynamic insertion and end-to-end operator fusion strategies, which greatly speeds up parallel inference.

- **Standard model inference**: PaddleNLP provides both dynamic-graph and static-graph inference, making it easy for users to quickly verify model inference results (including LoRA and PrefixTuning).

@@ -224,15 +264,15 @@ python predictor.py --model_name_or_path ./inference --inference_model --dtype "

For more on standard model inference and high-performance model usage, see the [LLM inference documentation](./docs/inference.md).

-### 5. Service Deployment
+### 6. Service Deployment

-#### 5.1 Environment Setup
+#### 6.1 Environment Setup

- python >= 3.8
- gradio
- flask

-#### 5.2 Flask & Gradio UI Service Deployment
+#### 6.2 Flask & Gradio UI Service Deployment

We provide a simple and easy-to-use UI service deployment script based on dynamic-graph inference, with which users can quickly deploy inference as a service.

@@ -253,7 +293,7 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \



-### 6. PyTorch Model Weight Conversion
+### 7. PyTorch Model Weight Conversion
PaddleNLP provides an interface that automatically converts PyTorch weights to Paddle weights; the code is as follows:

```python

llm/data.py

Lines changed: 8 additions & 9 deletions
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import annotations

import numpy as np

@@ -163,9 +162,9 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
    return tokenized_source, labels


-def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
    if tokenizer.chat_template is not None:
-        return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)

    tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
    if is_test:

@@ -183,21 +182,21 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
        features = {"input_ids": input_ids, "labels": labels}
        if "position_ids" in tokenized_source:
            features["position_ids"] = list(range(seq_length))
-        if intokens:
+        if zero_padding:
            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

        return features


-def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
    """convert multi-rounds conversation example

    Args:
        example (dict): the source of example
        tokenizer (PretrainedTokenizer): the instance of tokenizer
        data_args (DataArgument): data argument for data preprocessing
        is_test (bool, optional): whether is testing stage. Defaults to True.
-        intokens (bool, optional): whether use in_tokens. Defaults to False.
+        zero_padding (bool, optional): whether to use zero padding. Defaults to False.

    Returns:
        dict[str, np.ndarray]: the features of example

@@ -216,7 +215,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i
    seq_length = len(input_ids)
    features = {"input_ids": input_ids, "labels": labels}
-    if intokens:
+    if zero_padding:
        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

    if "position_ids" in rounds_inputs:

@@ -226,7 +225,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i
    return rounds_inputs


-def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False):
    if tokenizer.chat_template is not None:
        # chatglm only support single-round finetune
        example = convert_multi_rounds_to_single_round(example, tokenizer)

@@ -249,7 +248,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
        "labels": labels,
    }

-    if intokens:
+    if zero_padding:
        seq_length = len(input_ids)
        # attention_mask
        attention_mask = np.tri(seq_length, seq_length, dtype=bool)
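
For reference, the `np.tri(seq_length, seq_length, dtype=bool)` mask built when `zero_padding` is enabled is simply a lower-triangular boolean matrix, i.e. a causal attention mask in which each position can attend only to itself and to earlier positions. A small standalone illustration:

```python
import numpy as np

seq_length = 4
# Row i is True for columns 0..i, so token i attends to itself and earlier tokens only.
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
print(attention_mask)
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]
```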

llm/dpo_argument.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import TrainingArguments


def add_start_docstrings(*docstr):
    """Adds docstrings for a function."""

    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class DPOTrainingArguments(TrainingArguments):
    """DPOTrainingArguments"""

    unified_checkpoint: bool = field(
        default=True,
        metadata={"help": "Whether to use unified checkpoint."},
    )
    unified_checkpoint_config: Optional[str] = field(
        default="",
        metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
    )
    dpo_beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    dpo_label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
    dpo_loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})


@dataclass
class DPODataArgument:
    """DataArgument"""

    train_dataset_path: str = field(default="./data/train.jsonl", metadata={"help": "Path to the train dataset dir."})
    dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
    max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
    max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
    autotuner_benchmark: bool = field(
        default=False,
        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
    )
    benchmark: bool = field(
        default=False,
        metadata={"help": "Whether to run benchmark."},
    )
    greedy_intokens: bool = field(
        default=True,
        metadata={"help": "Whether apply greedy intokens."},
    )
    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})


@dataclass
class DPOModelArgument:
    """ModelArgument"""

    model_name_or_path: str = field(
        default=None, metadata={"help": "Pretrained model name or path to local directory."}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
    recompute_granularity: str = field(
        default="full",
        metadata={
            "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
        },
    )
    use_attn_mask_start_row_indices: bool = field(
        default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}
    )
    virtual_pp_degree: int = field(
        default=1,
        metadata={"help": "virtual_pp_degree"},
    )
    sequence_parallel: bool = field(
        default=False,
        metadata={"help": "whether to use sequence parallel"},
    )
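
A sketch of how these dataclasses are typically consumed, following the argument-parsing pattern used by other scripts under `llm/`; the exact code inside `dpo_train.py` is not shown in this commit, so treat the parsing call and the working directory (`llm/`) as assumptions:

```python
import sys

from paddlenlp.trainer import PdArgumentParser

from dpo_argument import DPODataArgument, DPOModelArgument, DPOTrainingArguments

# Parse either a JSON config (e.g. ./llama/dpo_argument.json) or command-line
# flags into the three dataclasses defined above.
parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    model_args, data_args, training_args = parser.parse_json_file(json_file=sys.argv[1])
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.model_name_or_path, data_args.train_dataset_path, training_args.dpo_beta)
```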
