
Commit c08c9a6

Merge branch 'dev_add_qwen1.5-moe' of github.com:DrownFish19/PaddleNLP into dev_add_qwen1.5-moe
2 parents b140df6 + 6455445

64 files changed: +4,055 −438 lines


csrc/generation/flash_attn_bwd.cc

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/extension.h"
+#include <iostream>
+#include <vector>
+
+using paddle::Tensor;
+
+namespace paddle {
+namespace experimental {
+
+// Declaration of the framework's flash-attention backward kernel.
+PADDLE_API void flash_attn_grad(const Tensor& q,
+                                const Tensor& k,
+                                const Tensor& v,
+                                const Tensor& out,
+                                const Tensor& softmax_lse,
+                                const Tensor& seed_offset,
+                                const paddle::optional<Tensor>& attn_mask,
+                                const Tensor& out_grad,
+                                float dropout,
+                                bool causal,
+                                Tensor* q_grad,
+                                Tensor* k_grad,
+                                Tensor* v_grad);
+
+}  // namespace experimental
+}  // namespace paddle
+
+std::vector<Tensor> SRFlashAttnBwd(const Tensor& q,
+                                   const Tensor& k,
+                                   const Tensor& v,
+                                   const Tensor& out,
+                                   const Tensor& softmax_lse,
+                                   const Tensor& seed_offset,
+                                   const paddle::optional<Tensor>& attn_mask,
+                                   const Tensor& out_grad,
+                                   float dropout,
+                                   bool causal);
+
+std::vector<Tensor> SRFlashAttnBwd(const Tensor& q,
+                                   const Tensor& k,
+                                   const Tensor& v,
+                                   const Tensor& out,
+                                   const Tensor& softmax_lse,
+                                   const Tensor& seed_offset,
+                                   const paddle::optional<Tensor>& attn_mask,
+                                   const Tensor& out_grad,
+                                   float dropout,
+                                   bool causal) {
+  // Forward to the framework kernel; the q/k/v gradients are returned in order.
+  std::vector<Tensor> res(3);
+  paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset,
+                                        attn_mask, out_grad, dropout, causal,
+                                        &res[0], &res[1], &res[2]);
+  return res;
+}
+
+std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype,
+                                                  paddle::DataType k_dtype,
+                                                  paddle::DataType v_dtype) {
+  // Each gradient keeps the dtype of its corresponding input.
+  return {q_dtype, k_dtype, v_dtype};
+}
+
+std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape(
+    std::vector<int64_t> q_shape, std::vector<int64_t> k_shape,
+    std::vector<int64_t> v_shape) {
+  // Each gradient keeps the shape of its corresponding input.
+  return {q_shape, k_shape, v_shape};
+}
+
+PD_BUILD_OP(flash_attn_bwd)
+    .Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"})
+    .Outputs({"q_grad", "k_grad", "v_grad"})
+    .Attrs({"dropout: float", "causal: bool"})
+    .SetKernelFn(PD_KERNEL(SRFlashAttnBwd))
+    .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype));
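
PD_BUILD_OP registers the kernel under the name flash_attn_bwd, so once the csrc extension is built the op becomes callable from Python. A minimal usage sketch follows; the module name `paddlenlp_ops`, the tensor shapes, and the `softmax_lse` layout are assumptions for illustration, and in real use `out`, `softmax_lse`, and `seed_offset` come from the forward flash-attention pass:

```python
import paddle
from paddlenlp_ops import flash_attn_bwd  # assumed module name for the built extension

batch, seqlen, heads, head_dim = 1, 128, 8, 64
q = paddle.randn([batch, seqlen, heads, head_dim], dtype="float16")
k = paddle.randn([batch, seqlen, heads, head_dim], dtype="float16")
v = paddle.randn([batch, seqlen, heads, head_dim], dtype="float16")

# Placeholders: in real use these are outputs of the forward flash-attention op.
out = paddle.randn([batch, seqlen, heads, head_dim], dtype="float16")
softmax_lse = paddle.randn([batch, heads, seqlen], dtype="float32")  # assumed layout
seed_offset = paddle.zeros([2], dtype="int64")
out_grad = paddle.ones_like(out)

# Argument order follows the .Inputs/.Attrs declared in PD_BUILD_OP(flash_attn_bwd).
q_grad, k_grad, v_grad = flash_attn_bwd(
    q, k, v, out, softmax_lse, seed_offset, None, out_grad, 0.0, True
)
```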

csrc/setup_cuda.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def get_gencode_flags():
         "./generation/step.cu",
         "./generation/quant_int8.cu",
         "./generation/dequant_int8.cu",
+        "./generation/flash_attn_bwd.cc",
     ],
     extra_compile_args={
         "cxx": ["-O3"],

docs/trainer.md

Lines changed: 9 additions & 1 deletion
@@ -576,7 +576,15 @@ Trainer is a simple but feature-complete Paddle training and evaluation module
     following config is supported:
     enable_allreduce_avg_in_gradinent_scale, which replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data_parallel, improving performance. ONLY supported in auto mode for now.
     gradient_sync_after_accumulate, which moves gradient-sync operations from backward into the optimizer step when gradient accumulation is enabled, reducing the number of syncs to improve performance at the cost of extra memory. ONLY supported in auto mode for now.
-
+--context_parallel_degree
+    Context parallelism is a parallel method that partitions the training data along the sequence dimension.
+    It uses Ring FlashAttention to keep the attention result correct after partitioning: the complete attention scores are obtained through ring communication and iterative updates.
+    The default is -1, meaning context parallelism is disabled.
+    (`int`, optional, defaults to `-1`)
+    (Note: this method requires model-structure changes; currently only LLaMA is supported.)
+    (Note: the communication overhead is significant; it is recommended only for very long sequences, e.g. 1024k.)
 --recompute
     Whether to use recompute during training. It saves GPU memory by recomputing the forward pass to obtain gradients, reducing the memory held by intermediate activations.
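
The added docs describe Ring FlashAttention's "ring communication and iterative updates" only in prose. Below is a toy, single-process NumPy sketch of the merge rule (an illustration, not the actual distributed kernel): each pass over a KV block folds that block's partial attention into a running log-sum-exp, and after visiting every block the result equals full-sequence softmax attention. In real context parallelism, each loop iteration corresponds to one ring send/recv step.

```python
import numpy as np

def ring_attention_merge(q, kv_blocks):
    """Merge per-block attention with a running log-sum-exp (toy sketch)."""
    acc = np.zeros((q.shape[0], kv_blocks[0][1].shape[1]))  # running softmax-weighted sum of V
    lse = np.full(q.shape[0], -np.inf)                      # running log-sum-exp per query row
    for k, v in kv_blocks:                                  # one iteration == one ring step
        s = q @ k.T                                         # raw scores against this KV block
        block_lse = np.log(np.exp(s).sum(-1))
        new_lse = np.logaddexp(lse, block_lse)
        # Rescale the old accumulator, then add this block's contribution.
        acc = acc * np.exp(lse - new_lse)[:, None] + np.exp(s - new_lse[:, None]) @ v
        lse = new_lse
    return acc

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
kv_blocks = [(rng.normal(size=(16, 8)), rng.normal(size=(16, 8))) for _ in range(3)]

# Reference: softmax over the concatenated full sequence, then weight V.
full_k = np.concatenate([k for k, _ in kv_blocks])
full_v = np.concatenate([v for _, v in kv_blocks])
scores = q @ full_k.T
weights = np.exp(scores - scores.max(-1, keepdims=True))
ref = (weights / weights.sum(-1, keepdims=True)) @ full_v

assert np.allclose(ring_attention_merge(q, kv_blocks), ref)
```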

examples/benchmark/wiki_lambada/eval.py

Lines changed: 24 additions & 33 deletions
@@ -73,7 +73,6 @@ def get_parser():
         default=False,
         help="Whether to use flash attention",
     )
-
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist weight file")

@@ -250,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
         last_token = text.split()[-1]
         start_idx = text.rfind(last_token)
         beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
-        last_token = tokenizer(" " + last_token)["input_ids"]
+        all_tokens = tokenizer(text.strip())["input_ids"]
+        last_token = all_tokens[len(beginning_tokens) :]
         return beginning_tokens, last_token

@@ -277,7 +277,7 @@ def create_eval_dataset(args):
         with open(args.eval_path, "r") as f:
             for line in f.readlines():
                 text = json.loads(line)["text"]
-                tokens, labels = get_tokens(tokenizer, text, strict=False)
+                tokens, labels = get_tokens(tokenizer, text, strict=True)
                 tokenized_data.append(tokens)
                 tokenized_label.append(labels)
         val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
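
The get_tokens change above avoids a BPE boundary artifact: tokenizing the prefix and " " + last_token separately is not guaranteed to reproduce the tokenization of the full string, whereas slicing the full tokenization is consistent by construction. A small sketch (the checkpoint name is an arbitrary assumption for illustration):

```python
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-en")  # assumed checkpoint

text = "the quick brown fox"
last_token = text.split()[-1]
start_idx = text.rfind(last_token)

beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
all_tokens = tokenizer(text.strip())["input_ids"]

# New behavior: the label is whatever completes the full tokenization.
labels = all_tokens[len(beginning_tokens):]
# Old behavior: a separately tokenized " fox" can disagree at the merge boundary.
print(labels, tokenizer(" " + last_token)["input_ids"])
```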
@@ -327,44 +327,35 @@ def do_generation():
     )

     model.eval()
-    args.use_pure_fp16 = False
-
     total_score = 0
     score_name = "loss" if not args.cloze_eval else "number correct"
-    args.use_pure_fp16 = False
     eval_data_loader = create_eval_dataset(args)
     with paddle.no_grad():
         for step, batch in enumerate(eval_data_loader):

             tokens, loss_mask = batch[:2]
             labels = batch[-1]
-            with paddle.amp.auto_cast(args.use_pure_fp16):
-                if args.model_type == "bloom":
-                    preds = model(tokens).detach()
-                else:
-                    preds = model(tokens)[0].detach()
-                # print(preds)
-
-                # cast preds to float32 to keep high-precision
-                preds = preds.astype(paddle.float32)
-
-                if not args.cloze_eval:
-                    masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
-                    loss = paddle.sum(masked_lm_loss * loss_mask)
-                    total_score += float(loss) / (args.num_tokenized_tokens - 1)
-                else:
-                    outputs = paddle.argmax(preds, -1)
-                    acc = paddle.cast(outputs == labels, "float32")
-                    acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
-                    acc = paddle.sum(paddle.prod(acc, -1))
-                    total_score += float(acc)
-
-                if step % args.logging_steps == 0:
-                    logger.info(
-                        "step %d, batch: %d, %s: %f, speed: %.2f step/s"
-                        % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
-                    )
-                    tic_eval = time.time()
+            preds = model(tokens, return_dict=True).logits.detach()
+            # cast preds to float32 to keep high precision
+            preds = preds.astype(paddle.float32)
+
+            if not args.cloze_eval:
+                masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
+                loss = paddle.sum(masked_lm_loss * loss_mask)
+                total_score += float(loss) / (args.num_tokenized_tokens - 1)
+            else:
+                outputs = paddle.argmax(preds, -1)
+                acc = paddle.cast(outputs == labels, "float32")
+                acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
+                acc = paddle.sum(paddle.prod(acc, -1))
+                total_score += float(acc)
+
+            if step % args.logging_steps == 0:
+                logger.info(
+                    "step %d, batch: %d, %s: %f, speed: %.2f step/s"
+                    % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
+                )
+                tic_eval = time.time()

     if not args.cloze_eval:
         total_loss = float(total_score)
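
The cloze branch kept by this refactor counts a sample as correct only if every unmasked position matches; a toy check of that reduction:

```python
import paddle

# outputs: argmax predictions; labels: gold tokens; loss_mask marks scored positions.
outputs = paddle.to_tensor([[5, 7, 9], [5, 7, 9]])
labels = paddle.to_tensor([[5, 7, 9], [5, 0, 9]])
loss_mask = paddle.to_tensor([[0.0, 1.0, 1.0], [0.0, 1.0, 1.0]])

acc = paddle.cast(outputs == labels, "float32")
# Unscored positions are forced to 1 so they cannot break the product.
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
print(paddle.prod(acc, -1))  # [1., 0.]: sample 1 fully correct, sample 2 not
```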

llm/README.md

Lines changed: 46 additions & 6 deletions
@@ -155,7 +155,47 @@ python finetune_generation.py ./llama/pt_argument.json

 For more documentation on distributed fine-tuning of large models, training details, and results, see the [LLM fine-tuning tutorial](./docs/finetune.md)

-### 3. Quantization
+### 3. Alignment
+We support preference-alignment strategies such as DPO.
+
+**Data preparation**
+
+The supported fine-tuning data format is a JSON file with one dictionary per line, where each dictionary contains the following fields:
+
+- `src` : `str, List(str)`, the user's conversation content.
+- `tgt` : `str, List(str)`, the system's reply content.
+- `response` : `str, List(str)`, the chosen and rejected replies.
+- `sort` : `List(int)`, the sort values distinguish chosen from rejected within `response` (the smaller sort value marks rejected, the larger marks chosen).
+
+Sample data:
+```
+{
+    "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
+    "tgt": [],
+    "response": [
+        "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
+        "As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
+    ],
+
+    "sort": [1, 0]
+}
+
+...
+```
+
+For convenient testing, we also provide a preference dataset ready to use:
+```bash
+wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
+tar -zxvf ultrafeedback_binarized.tar.gz
+```
+
+**Full-parameter alignment: DPO**
+```bash
+# Reference launch command for LLaMA DPO on eight GPUs
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" dpo_train.py ./llama/dpo_argument.json
+```
+
+### 4. Quantization
 LLM quantization converts 16- and 32-bit floating-point model parameters or activations into 4- or 8-bit integers, which effectively reduces model storage and compute requirements while accelerating inference. The toolchain's quantization algorithms include:
 - **PTQ**. The PaddleSlim team's self-developed adaptive Shift-SmoothQuant algorithm builds on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145),
   adding a PieceWiseSearch parameter-search algorithm that adjusts weight and activation distributions to reduce subsequent A8W8 PTQ quantization loss.
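
To make the `sort` convention concrete, here is a minimal sketch that recovers the chosen/rejected pair from one line of the data file (the example line is fabricated for illustration):

```python
import json

line = '{"src": ["Hello?"], "tgt": [], "response": ["short reply", "detailed, helpful reply"], "sort": [0, 1]}'
example = json.loads(line)

# The larger sort value marks the chosen response, the smaller the rejected one.
ranked = sorted(zip(example["sort"], example["response"]))
rejected, chosen = ranked[0][1], ranked[-1][1]
print("chosen:", chosen)
print("rejected:", rejected)
```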
@@ -184,7 +224,7 @@ python finetune_generation.py ./llama/ptq_argument.json
 For more technical details and model-quantization usage, see the [quantization docs](./docs/quantization.md)


-### 4. Inference
+### 5. Inference
 Beyond standard model inference, PaddleNLP also provides high-performance inference with built-in dynamic insertion and end-to-end operator fusion, which greatly speeds up parallel inference.

 - **Standard model inference**: PaddleNLP offers both dynamic-graph and static-graph inference, letting users quickly validate model outputs (including LoRA and PrefixTuning).
@@ -224,15 +264,15 @@ python predictor.py --model_name_or_path ./inference --inference_model --dtype "

 For more on standard and high-performance model inference, see the [LLM inference docs](./docs/inference.md)

-### 5. Service deployment
+### 6. Service deployment

-#### 5.1 Environment setup
+#### 6.1 Environment setup

 - python >= 3.8
 - gradio
 - flask

-#### 5.2 Flask & Gradio UI service deployment
+#### 6.2 Flask & Gradio UI service deployment

 We provide a simple, easy-to-use UI deployment script based on dynamic-graph inference so users can quickly stand up an inference service.

@@ -253,7 +293,7 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \



-### 6. PyTorch model weight conversion
+### 7. PyTorch model weight conversion
 PaddleNLP provides an interface that automatically converts PyTorch weights into Paddle weights; the code is as follows:

 ```python

llm/data.py

Lines changed: 8 additions & 9 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations

 import numpy as np

@@ -173,9 +172,9 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
     return tokenized_source, labels


-def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
     if tokenizer.chat_template is not None:
-        return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)

     tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
     if is_test:
@@ -193,21 +192,21 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
     features = {"input_ids": input_ids, "labels": labels}
     if "position_ids" in tokenized_source:
         features["position_ids"] = list(range(seq_length))
-    if intokens:
+    if zero_padding:
         features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

     return features


-def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
     """convert multi-rounds conversation example

     Args:
         example (dict): the source of example
         tokenizer (PretrainedTokenizer): the instance of tokenizer
         data_args (DataArgument): data argument for data preprocessing
         is_test (bool, optional): whether is testing stage. Defaults to True.
-        intokens (bool, optional): whether use in_tokens. Defaults to False.
+        zero_padding (bool, optional): whether to use zero padding. Defaults to False.

     Returns:
         dict[str, np.ndarray]: the features of example
@@ -226,7 +225,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):

     seq_length = len(input_ids)
     features = {"input_ids": input_ids, "labels": labels}
-    if intokens:
+    if zero_padding:
         features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

     if "position_ids" in rounds_inputs:
@@ -236,7 +235,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
     return rounds_inputs


-def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):
+def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False):
     if tokenizer.chat_template is not None:
         # chatglm only supports single-round finetune
         example = convert_multi_rounds_to_single_round(example, tokenizer)
@@ -259,7 +258,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):
         "labels": labels,
     }

-    if intokens:
+    if zero_padding:
         seq_length = len(input_ids)
         # attention_mask
         attention_mask = np.tri(seq_length, seq_length, dtype=bool)
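
For reference, the mask built by the zero-padding branches above is a plain lower-triangular (causal) matrix:

```python
import numpy as np

seq_length = 4
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
print(attention_mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```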
