Skip to content

Commit 5ba7a94

Browse files
authored
[xpu] add xpu custom ops support for llama2-7b (#8515)
1 parent 2b557e2 commit 5ba7a94

37 files changed

+4014
-0
lines changed

csrc/xpu/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# ernie-bot-custom-ops
2+
ernie bot 昆仑自定义算子库。
3+
4+
## 快速开始
5+
# 构建 XDNN plugin 和 Paddle 自定义算子库
6+
```
7+
$ cd src
8+
$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xdnn-ubuntu_x86_64.tar.gz
9+
$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xre-ubuntu_x86_64.tar.gz
10+
$ wget -q --no-check-certificate https://klx-sdk-release-public.su.bcebos.com/xtdk_llvm15/dev/2.7.98.2/xtdk-llvm15-ubuntu1604_x86_64.tar.gz
11+
$ tar -xf xdnn-ubuntu_x86_64.tar.gz
12+
$ tar -xf xre-ubuntu_x86_64.tar.gz
13+
$ tar -xf xtdk-llvm15-ubuntu1604_x86_64.tar.gz
14+
$ export PWD=$(pwd)
15+
$ export XDNN_PATH=${PWD}/xdnn-ubuntu_x86_64/
16+
$ export XRE_PATH=${PWD}/xre-ubuntu_x86_64/
17+
$ export CLANG_PATH=${PWD}/xtdk-llvm15-ubuntu1604_x86_64/
18+
$ bash ./cmake_build.sh
19+
```
20+
21+
## 测试
22+
# 运行 get_padding_offset_v2 单测
23+
```
24+
$ cd test/python
25+
$ python test_get_padding_offset_v2.py
26+
```
27+
28+
## 如何贡献
29+
```
30+
$ pip install pre-commit==2.17.0
31+
$ pre-commit install
32+
```

csrc/xpu/src/cmake_build.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash
2+
3+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# Abort the whole build as soon as any command below fails.
set -e
18+
19+
# Optional overrides for the toolchain locations; uncomment and adjust as
# needed before running this script.
# export XDNN_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xdnn-ubuntu_x86_64/ # <path_to_xdnn>
20+
# export XRE_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xre-ubuntu_x86_64/ # <path_to_xre>
21+
# export CLANG_PATH=xtdk-ubuntu_1604_x86_64 # <path_to_xtdk>
22+
# export HOST_SYSROOT=/opt/compiler/gcc-8.2/bin/gcc # <path_to_gcc>
23+
24+
# Step 1: run the plugin's own build script from inside ./plugin, then return
# to the previous directory. The relative path means this script must be
# started from the directory that contains ./plugin.
cd plugin
25+
./cmake_build.sh
26+
cd -
27+
28+
# Step 2: remove any previously installed paddlenlp_ops package, then install
# the package described by setup.py in the current directory.
python -m pip uninstall paddlenlp_ops -y
29+
python setup.py install

csrc/xpu/src/get_padding_offset_v2.cc

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "xpu/plugin.h"
18+
19+
// Kernel entry point for the `get_padding_offset_v2` custom op on XPU.
//
// Allocates the output tensors and forwards everything to
// baidu::xpu::api::plugin::get_padding_offset.
//
// Inputs (element types are the ones the plugin call below reads them as):
//   input_ids   - read as int64; dimension 1 is used as the padded sequence
//                 length.
//   cum_offsets - read as int32; passed to the plugin as an input and also
//                 copied into the `cum_offsets_out` output buffer.
//   token_num   - read as int64; its first element sizes `x_remove_padding`
//                 and `padding_offset`.
//   seq_len     - read as int32; dimension 0 is used as the batch size.
//
// Returns, in order: x_remove_padding (int64, [token_num]),
// cum_offsets_out (copy of cum_offsets), padding_offset (int32, [token_num]),
// cu_seqlens_q and cu_seqlens_k (int32, [bsz + 1]).
//
// Fails via PD_CHECK when the plugin returns a non-zero status.
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
                                             const paddle::Tensor& cum_offsets,
                                             const paddle::Tensor& token_num,
                                             const paddle::Tensor& seq_len) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  const std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int seq_length = input_ids_shape[1];
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);

  // The host pointer is dereferenced on the very next statement, so this copy
  // must be complete before we continue: request a blocking copy.
  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), true);
  const int token_num_data =
      static_cast<int>(cpu_token_num.data<int64_t>()[0]);

  // Outputs are zero-filled on the same place as input_ids.
  auto x_remove_padding = paddle::full(
      {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::full(
      {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());

  int r = baidu::xpu::api::plugin::get_padding_offset(
      xpu_ctx->x_context(),
      padding_offset.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      seq_length,
      bsz);
  PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");

  return {x_remove_padding,
          cum_offsets_out,
          padding_offset,
          cu_seqlens_q,
          cu_seqlens_k};
}
62+
63+
// Shape inference for the padding-offset op.
//
// Only the leading dimension of `seq_len_shape` (the batch size, bsz) is
// consulted; the other input shapes are accepted to match the op's input list
// but are not read.
//
// Returns five shapes, in output order:
//   {-1}      - first output, length reported as dynamic (-1)
//   {bsz}     - second output
//   {-1}      - third output, length reported as dynamic (-1)
//   {bsz + 1} - fourth output
//   {bsz + 1} - fifth output
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape) {
  const int64_t bsz = seq_len_shape[0];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
72+
73+
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
74+
const paddle::DataType& input_ids_dtype,
75+
const paddle::DataType& cum_offsets_dtype,
76+
const paddle::DataType& token_num_dtype,
77+
const paddle::DataType& seq_len_dtype) {
78+
return {input_ids_dtype,
79+
seq_len_dtype,
80+
seq_len_dtype,
81+
seq_len_dtype,
82+
seq_len_dtype};
83+
}
84+
85+
// Registers the custom operator `get_padding_offset_v2`, wiring it to the
// GetPaddingOffset kernel and to its shape/dtype inference functions.
//
// NOTE(review): the input names below are listed as
// (input_ids, token_num, cum_offsets, seq_len), whereas GetPaddingOffset
// declares its parameters as (input_ids, cum_offsets, token_num, seq_len).
// If inputs are bound to the kernel by position, the names "token_num" and
// "cum_offsets" label each other's tensors - confirm against callers before
// renaming either side.
PD_BUILD_OP(get_padding_offset_v2)
86+
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
87+
.Outputs({"x_remove_padding",
88+
"cum_offsets_out",
89+
"padding_offset",
90+
"cu_seqlens_q",
91+
"cu_seqlens_k"})
92+
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
93+
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
94+
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "paddle/phi/core/enforce.h"
18+
#include "xpu/plugin.h"
19+
20+
// Kernel entry point for the `get_token_penalty_multi_scores_v2` custom op on
// XPU.
//
// Dispatches on the dtype of `logits` (float16 or float32) and forwards all
// tensors to baidu::xpu::api::plugin::token_penalty_multi_scores.
//
// `logits` is handed to the plugin through a non-const pointer (const_cast),
// i.e. the plugin is expected to update it in place.
//
// Element types, as read by the plugin call:
//   pre_ids, cur_len, min_len, eos_token_id, bad_tokens - int64
//   penalty_scores, frequency_scores, presence_scores   - same dtype as logits
//   temperatures                                        - float32
//
// Sizes taken from the inputs:
//   bs               = logits.shape()[0]   (must be <= 640)
//   length           = logits.shape()[1]
//   length_id        = pre_ids.shape()[1]
//   length_bad_words = bad_tokens.shape()[0]
//   end_length       = eos_token_id.shape()[0]
//
// Raises InvalidArgument when bs exceeds the limit, throws via PD_THROW for
// any other logits dtype, and fails via PD_CHECK when the plugin returns a
// non-zero status.
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                             const paddle::Tensor& logits,
                             const paddle::Tensor& penalty_scores,
                             const paddle::Tensor& frequency_scores,
                             const paddle::Tensor& presence_scores,
                             const paddle::Tensor& temperatures,
                             const paddle::Tensor& bad_tokens,
                             const paddle::Tensor& cur_len,
                             const paddle::Tensor& min_len,
                             const paddle::Tensor& eos_token_id) {
  // Largest batch size accepted by this op. The limit and the error message
  // are derived from this single constant so they cannot drift apart.
  constexpr int64_t kMaxBatchSize = 640;

  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  int64_t bs = logits.shape()[0];
  PADDLE_ENFORCE_LE(
      bs,
      kMaxBatchSize,
      phi::errors::InvalidArgument(
          "Only support bsz <= %d, but received bsz is %d",
          kMaxBatchSize,
          bs));

  int64_t length = logits.shape()[1];
  int64_t length_id = pre_ids.shape()[1];
  int64_t length_bad_words = bad_tokens.shape()[0];
  int64_t end_length = eos_token_id.shape()[0];

  switch (logits.type()) {
    case paddle::DataType::FLOAT16: {
      // Paddle's float16 is reinterpreted as the XPU-native half type.
      using XPUType = typename XPUTypeTrait<float16>::Type;
      using data_t = paddle::float16;
      int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
          xpu_ctx->x_context(),
          pre_ids.data<int64_t>(),
          reinterpret_cast<XPUType*>(
              const_cast<data_t*>(logits.data<data_t>())),
          reinterpret_cast<const XPUType*>(penalty_scores.data<data_t>()),
          reinterpret_cast<const XPUType*>(frequency_scores.data<data_t>()),
          reinterpret_cast<const XPUType*>(presence_scores.data<data_t>()),
          temperatures.data<float>(),
          cur_len.data<int64_t>(),
          min_len.data<int64_t>(),
          eos_token_id.data<int64_t>(),
          bad_tokens.data<int64_t>(),
          bs,
          length,
          length_id,
          end_length,
          length_bad_words);
      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
    } break;
    case paddle::DataType::FLOAT32: {
      int r = baidu::xpu::api::plugin::token_penalty_multi_scores(
          xpu_ctx->x_context(),
          pre_ids.data<int64_t>(),
          const_cast<float*>(logits.data<float>()),
          penalty_scores.data<float>(),
          frequency_scores.data<float>(),
          presence_scores.data<float>(),
          temperatures.data<float>(),
          cur_len.data<int64_t>(),
          min_len.data<int64_t>(),
          eos_token_id.data<int64_t>(),
          bad_tokens.data<int64_t>(),
          bs,
          length,
          length_id,
          end_length,
          length_bad_words);
      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
    } break;
    default:
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and float32 are supported. ");
  }
}
94+
95+
// Registers the custom operator `get_token_penalty_multi_scores_v2`, backed by
// the TokenPenaltyMultiScores kernel. The input names are listed in the same
// order as that kernel's parameters.
//
// The single output `logits_out` is mapped in place onto the `logits` input
// via SetInplaceMap. No shape or dtype inference functions are registered
// here.
PD_BUILD_OP(get_token_penalty_multi_scores_v2)
96+
.Inputs({"pre_ids",
97+
"logits",
98+
"penalty_scores",
99+
"frequency_scores",
100+
"presence_scores",
101+
"temperatures",
102+
"bad_tokens",
103+
"cur_len",
104+
"min_len",
105+
"eos_token_id"})
106+
.Outputs({"logits_out"})
107+
.SetInplaceMap({{"logits", "logits_out"}})
108+
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));

0 commit comments

Comments
 (0)