Skip to content

Commit 4e814b7

Browse files
NeroLohneroluo
authored andcommitted
[xpu] add xpu custom ops support for llama2-7b
1 parent c1cfe63 commit 4e814b7

40 files changed

+4289
-0
lines changed

csrc/xpu/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# ernie-bot-custom-ops
2+
ernie bot 昆仑自定义算子库。
3+
4+
## 快速开始
5+
# 构建 XDNN plugin 和 Paddle 自定义算子库
6+
```
7+
$ cd src
8+
$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xdnn-ubuntu_x86_64.tar.gz
9+
$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xre-ubuntu_x86_64.tar.gz
10+
$ wget -q --no-check-certificate https://klx-sdk-release-public.su.bcebos.com/xtdk_llvm15/dev/2.7.98.2/xtdk-llvm15-ubuntu1604_x86_64.tar.gz
11+
$ tar -xf xdnn-ubuntu_x86_64.tar.gz
12+
$ tar -xf xre-ubuntu_x86_64.tar.gz
13+
$ tar -xf xtdk-llvm15-ubuntu1604_x86_64.tar.gz
14+
$ export PWD=$(pwd)
15+
$ export XDNN_PATH=${PWD}/xdnn-ubuntu_x86_64/
16+
$ export XRE_PATH=${PWD}/xre-ubuntu_x86_64/
17+
$ export CLANG_PATH=${PWD}/xtdk-llvm15-ubuntu1604_x86_64/
18+
$ bash ./cmake_build.sh
19+
```
20+
21+
## 测试
22+
# 运行 get_padding_offset_v2 单测
23+
```
24+
$ cd test/python
25+
$ python test_get_padding_offset_v2.py
26+
```
27+
28+
## 如何贡献
29+
```
30+
$ pip install pre-commit==2.17.0
31+
$ pre-commit install
32+
```
33+
34+
## 讨论
35+
如果遇到问题,可以联系 luowei14@baidu.com, zhupengyang@baidu.com, shentanyue01@baidu.com, jiangfan06@baidu.com, hongming@baidu.com 解决。

csrc/xpu/src/cmake_build.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Build the XDNN plugin and (re)install the paddlenlp_ops custom-op package.
# Abort immediately if any command fails.
set -e

# Toolchain locations expected in the environment (see csrc/xpu/README.md):
# export XDNN_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xdnn-ubuntu_x86_64/ # <path_to_xdnn>
# export XRE_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xre-ubuntu_x86_64/ # <path_to_xre>
# export CLANG_PATH=xtdk-ubuntu_1604_x86_64 # <path_to_xtdk>
# export HOST_SYSROOT=/opt/compiler/gcc-8.2/bin/gcc # <path_to_gcc>

# Build the XDNN plugin first; the custom ops link against it.
cd plugin
./cmake_build.sh
cd -

# Reinstall the Python package that exposes the custom ops.
python -m pip uninstall paddlenlp_ops -y
python setup.py install

csrc/xpu/src/get_output.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <stdio.h>
16+
#include <string.h>
17+
#include <sys/ipc.h>
18+
#include <sys/msg.h>
19+
#include <sys/types.h>
20+
#include "paddle/extension.h"
21+
22+
// Maximum batch size a single queue message can carry.
#define MAX_BSZ 512

// System V message-queue payload used to hand decoded tokens to the reader.
struct msgdata {
  long mtype;              // message type; required leading field for msgsnd/msgrcv
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};
28+
29+
// Receives one message from the System V message queue and copies its
// payload into the (inplace) output tensor `x`.
//
// Payload layout (see struct msgdata): mtext[0] = stop flag, mtext[1] = batch
// size, mtext[2 .. bsz+1] = one token id per batch entry.
//
// x: int64 tensor written in place with the message payload.
//    NOTE(review): assumes x holds at least MAX_BSZ + 2 elements — confirm
//    against the Python caller.
// rank_id: only rank 0 reads the queue; all other ranks return immediately.
// wait_flag: false -> poll with IPC_NOWAIT; true -> block until a message
//            arrives.
void GetOutput(const paddle::Tensor& x,
               int64_t rank_id,
               bool wait_flag) {
  if (rank_id > 0) return;

  static struct msgdata msg_rcv;

  // NOTE(review): ftok("./", 1) depends on the current working directory;
  // sender and receiver must run from the same directory to get the same key.
  static key_t key = ftok("./", 1);

  static int msgid = msgget(key, IPC_CREAT | 0666);

  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());

  // msgrcv's size argument is the payload size excluding the leading mtype
  // field; use sizeof(mtext) rather than a hard-coded element size of 4.
  const size_t payload_bytes = sizeof(msg_rcv.mtext);
  int ret = -1;
  if (!wait_flag) {
    ret = msgrcv(msgid, &msg_rcv, payload_bytes, 0, IPC_NOWAIT);
  } else {
    ret = msgrcv(msgid, &msg_rcv, payload_bytes, 0, 0);
  }
  if (ret == -1) {
    // No message available (or the queue call failed): signal "read none" so
    // the caller can distinguish it from a real token batch.
    out_data[0] = -2;
    out_data[1] = 0;
    return;
  }

  int bsz = msg_rcv.mtext[1];

  // Copy stop flag, batch size, and bsz token ids, widening int -> int64.
  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = (int64_t)msg_rcv.mtext[i];
  }
  return;
}
62+
63+
// Register the custom op. The output aliases the input (inplace map), so the
// kernel writes the received message directly into `x`.
PD_BUILD_OP(get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t",
            "wait_flag: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutput));

csrc/xpu/src/get_padding_offset_v2.cc

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "xpu/plugin.h"
18+
19+
// Strips padding from input_ids and builds the offset bookkeeping tensors
// consumed by downstream attention kernels. The heavy lifting is delegated to
// the XPU plugin kernel get_padding_offset.
//
// Returns {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q,
// cu_seqlens_k}.
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
                                             const paddle::Tensor& cum_offsets,
                                             const paddle::Tensor& token_num,
                                             const paddle::Tensor& seq_len) {
  // Resolve the device context of the current XPU.
  phi::XPUPlace xpu_place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* base_ctx =
      paddle::experimental::DeviceContextPool::Instance().Get(xpu_place);
  auto* ctx = static_cast<const phi::XPUContext*>(base_ctx);

  const std::vector<int64_t> ids_shape = input_ids.shape();
  const int batch_size = seq_len.shape()[0];
  const int max_seq_len = ids_shape[1];

  // Device-side working copy of cum_offsets, plus a host copy of the token
  // count so output sizes can be computed on the CPU.
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
  auto token_num_cpu = token_num.copy_to(paddle::CPUPlace(), false);
  const int total_tokens = token_num_cpu.data<int64_t>()[0];

  auto x_remove_padding = paddle::full(
      {total_tokens}, 0, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::full(
      {total_tokens}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q = paddle::full(
      {batch_size + 1}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k = paddle::full(
      {batch_size + 1}, 0, paddle::DataType::INT32, input_ids.place());

  const int r = baidu::xpu::api::plugin::get_padding_offset(
      ctx->x_context(),
      padding_offset.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      max_seq_len,
      batch_size);
  PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");

  return {x_remove_padding,
          cum_offsets_out,
          padding_offset,
          cu_seqlens_q,
          cu_seqlens_k};
}
62+
63+
// Static shape inference for get_padding_offset_v2. The token-count-dependent
// outputs (x_remove_padding, padding_offset) cannot be known statically and
// are reported as {-1}.
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape) {
  // Batch size comes from seq_len; the original also read input_ids_shape[1]
  // into an unused local, which has been removed.
  const int64_t bsz = seq_len_shape[0];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
72+
73+
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
74+
const paddle::DataType& input_ids_dtype,
75+
const paddle::DataType& cum_offsets_dtype,
76+
const paddle::DataType& token_num_dtype,
77+
const paddle::DataType& seq_len_dtype) {
78+
return {input_ids_dtype,
79+
seq_len_dtype,
80+
seq_len_dtype,
81+
seq_len_dtype,
82+
seq_len_dtype};
83+
}
84+
85+
// Register the custom op.
// NOTE(review): Inputs are declared as {input_ids, token_num, cum_offsets,
// seq_len}, while the kernel signature is (input_ids, cum_offsets, token_num,
// seq_len). Confirm whether the framework binds inputs positionally and
// whether this ordering is intentional.
PD_BUILD_OP(get_padding_offset_v2)
    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
    .Outputs({"x_remove_padding",
              "cum_offsets_out",
              "padding_offset",
              "cu_seqlens_q",
              "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffset))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "paddle/phi/core/enforce.h"
18+
#include "xpu/plugin.h"
19+
20+
// Adjusts `logits` in place via the XPU plugin kernel: applies the penalty,
// frequency and presence scores, temperature scaling, bad-token masking, and
// EOS handling driven by cur_len / min_len. Supports fp16 and fp32 logits.
//
// Shapes read here: logits is [bs, length]; pre_ids' second dim is length_id;
// bad_tokens and eos_token_id are 1-D id lists.
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                             const paddle::Tensor& logits,
                             const paddle::Tensor& penalty_scores,
                             const paddle::Tensor& frequency_scores,
                             const paddle::Tensor& presence_scores,
                             const paddle::Tensor& temperatures,
                             const paddle::Tensor& bad_tokens,
                             const paddle::Tensor& cur_len,
                             const paddle::Tensor& min_len,
                             const paddle::Tensor& eos_token_id) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  int64_t bs = logits.shape()[0];
  // Enforced limit is 640; the original message incorrectly said 1024.
  PADDLE_ENFORCE_LE(
      bs,
      640,
      phi::errors::InvalidArgument(
          "Only support bsz <= 640, but received bsz is %d", bs));
  int64_t length = logits.shape()[1];
  int64_t length_id = pre_ids.shape()[1];
  int64_t length_bad_words = bad_tokens.shape()[0];
  int64_t end_length = eos_token_id.shape()[0];
  switch (logits.type()) {
    case paddle::DataType::FLOAT16: {
      // fp16 path: reinterpret the tensor buffers as the XPU half type.
      using XPUType = typename XPUTypeTrait<float16>::Type;
      typedef paddle::float16 data_t;
      int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
          xpu_ctx->x_context(),
          pre_ids.data<int64_t>(),
          reinterpret_cast<XPUType*>(
              const_cast<data_t*>(logits.data<data_t>())),
          reinterpret_cast<const XPUType*>(penalty_scores.data<data_t>()),
          reinterpret_cast<const XPUType*>(frequency_scores.data<data_t>()),
          reinterpret_cast<const XPUType*>(presence_scores.data<data_t>()),
          temperatures.data<float>(),
          cur_len.data<int64_t>(),
          min_len.data<int64_t>(),
          eos_token_id.data<int64_t>(),
          bad_tokens.data<int64_t>(),
          bs,
          length,
          length_id,
          end_length,
          length_bad_words);
      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
    } break;
    case paddle::DataType::FLOAT32: {
      int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
          xpu_ctx->x_context(),
          pre_ids.data<int64_t>(),
          const_cast<float*>(logits.data<float>()),
          penalty_scores.data<float>(),
          frequency_scores.data<float>(),
          presence_scores.data<float>(),
          temperatures.data<float>(),
          cur_len.data<int64_t>(),
          min_len.data<int64_t>(),
          eos_token_id.data<int64_t>(),
          bad_tokens.data<int64_t>(),
          bs,
          length,
          length_id,
          end_length,
          length_bad_words);
      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
    } break;
    default:
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and float32 are supported. ");
      break;
  }
}
94+
95+
// Register the custom op. logits is updated in place and re-exposed as
// logits_out via the inplace map.
PD_BUILD_OP(get_token_penalty_multi_scores_v2)
    .Inputs({"pre_ids",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id"})
    .Outputs({"logits_out"})
    .SetInplaceMap({{"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));

0 commit comments

Comments
 (0)