Skip to content

Commit b791375

Browse files
committed
support llama avx model inference
1 parent 007b653 commit b791375

15 files changed

+2141
-77
lines changed

csrc/cpu/0001-patch-fp16-and-bf16.patch

Lines changed: 280 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/0001-patch-fp32.patch

Lines changed: 302 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# cpu-custom-ops
2+
3+
## 快速开始
4+
# 构建 cpu 自定义算子库
5+
```
6+
# 前提条件:机器支持 avx 指令
7+
$ bash setup.sh
8+
```

csrc/cpu/setup.sh

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Builds the patched xFasterTransformer runtime and installs the CPU custom ops.

# 1. download XFT (skip when the checkout already exists)
# BUGFIX: the original test was `[ ! -d xFasterTransformer]`; the missing
# space before `]` makes the `[` builtin fail with "missing `]'".
if [ ! -d xFasterTransformer ]; then
    git clone --branch v1.7.2 https://github.com/intel/xFasterTransformer.git
fi

# 2. reset the working tree, then select the patch matching the CPU features
cd xFasterTransformer
git checkout .
cd ..

if lscpu | grep -q "avx512_bf16"; then
    echo "apply bf16 and fp16."
    if [ ! -f 0001-patch-fp16-and-bf16.patch ]; then
        echo "Error: 0001-patch-fp16-and-bf16.patch does not exist."
        exit 1
    fi
    # apply patch
    cp ./0001-patch-fp16-and-bf16.patch ./xFasterTransformer/paddle.patch
else
    echo "apply fp32"
    if [ ! -f 0001-patch-fp32.patch ]; then
        # BUGFIX: message read "Error: does 0001-patch-fp32.patch not exist."
        echo "Error: 0001-patch-fp32.patch does not exist."
        exit 1
    fi
    cp ./0001-patch-fp32.patch ./xFasterTransformer/paddle.patch
fi

# 3. apply patch
cd xFasterTransformer
git apply paddle.patch

# 4. build xFasterTransformer (oneCCL first, then the library itself)
sh ./3rdparty/prepare_oneccl.sh
source ./3rdparty/oneccl/build/_install/env/setvars.sh

rm -rf build
mkdir build && cd build
cmake ..
make -j

# xft: locations consumed by src/setup_cpu.py.
# NOTE(review): $PWD here is .../xFasterTransformer/build, so
# XFT_HEADER_DIR points at the build tree and XFT_LIB_DIR at build/build.
# setup_cpu.py joins XFT_HEADER_DIR with src/common, 3rdparty/... which live
# in the repo ROOT — these exports look like they should be set from the
# xFasterTransformer root instead. Kept as-is; confirm against a working run.
export XFT_HEADER_DIR=$PWD
export XFT_LIB_DIR=$XFT_HEADER_DIR/build
export LD_LIBRARY_PATH=$XFT_LIB_DIR:$LD_LIBRARY_PATH

# setup cpu paddle_nlp ops
cd ..
python ./src/setup_cpu.py install

csrc/cpu/src/set_value_by_flags.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/extension.h"
16+
17+
// Scatter the freshly sampled token of each still-active sequence into its
// per-sequence history buffer.
//
// Args:
//   stop_flags:  [bs]          true -> sequence finished, history untouched
//   pre_ids_all: [bs, length]  token history, row `bi` updated in place
//   pre_ids:     [bs]          token generated at the current step
//   step_idx:    [bs]          write position inside the history row
//   bs:          batch size
//   length:      row stride / capacity of pre_ids_all
void set_value_by_flag_and_id(const bool *stop_flags,
                              int64_t *pre_ids_all,
                              const int64_t *pre_ids,
                              const int64_t *step_idx,
                              int bs,
                              int length) {
  for (int bi = 0; bi < bs; bi++) {
    if (stop_flags[bi]) {
      continue;  // finished sequences keep their history as-is
    }
    int64_t *row = pre_ids_all + static_cast<int64_t>(bi) * length;
    const int64_t idx = step_idx[bi];
    // BUGFIX: the original only checked idx >= 0; a step index at or past
    // `length` was an out-of-bounds write into the next row (or worse).
    if (idx >= 0 && idx < length) {
      row[idx] = pre_ids[bi];
    }
  }
}
27+
28+
// Forward for set_value_by_flags_and_idx: records each live sequence's newest
// token into `pre_ids_all` (an input tensor mutated in place) and returns a
// copy of `stop_flags` for the op's declared output slot.
std::vector<paddle::Tensor> SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
                                                  const paddle::Tensor& pre_ids_now,
                                                  const paddle::Tensor& step_idx,
                                                  const paddle::Tensor& stop_flags) {
  const int batch_size = stop_flags.shape()[0];
  const int history_len = pre_ids_all.shape()[1];

  // The op returns stop_flags unchanged; clone it for the output.
  auto stop_flags_out = stop_flags.copy_to(stop_flags.place(), false);

  // pre_ids_all is updated in place, hence the const_cast on its buffer.
  set_value_by_flag_and_id(stop_flags.data<bool>(),
                           const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
                           pre_ids_now.data<int64_t>(),
                           step_idx.data<int64_t>(),
                           batch_size,
                           history_len);

  return {stop_flags_out};
}
38+
39+
// InferShape: the op's single output mirrors the shape of `stop_flags`.
std::vector<std::vector<int64_t>> SetValueByFlagsAndIdxInferShape(
    const std::vector<int64_t>& pre_ids_all_shape,
    const std::vector<int64_t>& pre_ids_now_shape,
    const std::vector<int64_t>& step_idx_shape,
    const std::vector<int64_t>& stop_flags_shape) {
  return {stop_flags_shape};
}
43+
44+
// InferDtype: the output dtype follows `stop_flags` (expected BOOL).
std::vector<paddle::DataType> SetValueByFlagsAndIdxInferDtype(
    const paddle::DataType& pre_ids_all_dtype,
    const paddle::DataType& pre_ids_now_dtype,
    const paddle::DataType& step_idx_dtype,
    const paddle::DataType& stop_flags_dtype) {
  return {stop_flags_dtype};
}
50+
51+
// Register the custom op with Paddle. Note the kernel mutates the
// `pre_ids_all` INPUT in place (via const_cast); the only declared output is
// the pass-through copy of `stop_flags`.
PD_BUILD_OP(set_value_by_flags_and_idx)
    .Inputs({"pre_ids_all", "pre_ids_now", "step_idx", "stop_flags"})
    .Outputs({"stop_flags_out"})
    .SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx))
    .SetInferShapeFn(PD_INFER_SHAPE(SetValueByFlagsAndIdxInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SetValueByFlagsAndIdxInferDtype));

csrc/cpu/src/setup_cpu.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import site
17+
import subprocess
18+
19+
from paddle.utils.cpp_extension import CppExtension, setup
20+
21+
# from setuptools import Extension, setup
22+
from setuptools.command.build_ext import build_ext
23+
24+
25+
# refer: https://note.qidong.name/2018/03/setup-warning-strict-prototypes
26+
# Avoid a gcc warning below:
27+
# cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid
28+
# for C/ObjC but not for C++
29+
class BuildExt(build_ext):
    """build_ext variant that strips '-Wstrict-prototypes' from the compile
    command, since that flag is only valid for C and gcc warns on C++."""

    def build_extensions(self):
        compile_flags = self.compiler.compiler_so
        if "-Wstrict-prototypes" in compile_flags:
            compile_flags.remove("-Wstrict-prototypes")
        super().build_extensions()
34+
35+
36+
def check_avx512_bf16__support():
    """Return True when the host CPU advertises the ``avx512_bf16`` flag.

    BUGFIX: the original passed a list together with ``shell=True`` to
    ``subprocess.run``; in that mode only the first element (``lscpu``) is the
    shell command and the ``|``/``grep`` elements become unused shell
    arguments, so the pipe never ran.  Run ``lscpu`` alone and search its
    output in Python instead.
    """
    try:
        result = subprocess.run(
            ["lscpu"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        return "avx512_bf16" in result.stdout.lower()
    except Exception as e:
        # Best-effort probe: a missing lscpu / non-Linux host falls back to
        # the conservative fp32 build path.
        print(f"Error checking AVX512 support: {e}")
        return False
54+
55+
56+
# cc flags shared by every translation unit of the extension.
paddle_extra_compile_args = [
    "-std=c++17",
    "-shared",
    "-fPIC",
    "-Wno-parentheses",
    "-DPADDLE_WITH_CUSTOM_KERNEL",
]

# Pick weight-only kernel defines matching the CPU: bf16 path when
# avx512_bf16 is available, fp32 fallback otherwise.
if check_avx512_bf16__support():
    # BUGFIX: the original listed -DAVX512_BF16_WEIGHT_ONLY_BF16=true TWICE
    # (copy-paste duplicate); the repeat was a no-op and is removed.
    # NOTE(review): the second entry may have been intended as an FP16
    # variant -- confirm against the xFasterTransformer build options.
    paddle_extra_compile_args += [
        "-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
    ]
else:
    paddle_extra_compile_args += [
        "-DAVX512_FP32_WEIGHT_ONLY_FP16=true",
        "-DAVX512_FP32_WEIGHT_ONLY_INT8=true",
    ]

# include path: paddle headers may live in any site-packages candidate.
site_packages_path = site.getsitepackages()
paddle_custom_kernel_include = [os.path.join(path, "paddle", "include") for path in site_packages_path]

# XFT locations are exported by setup.sh before this script runs; a missing
# variable raises KeyError with the variable name, which is informative enough.
XFT_INCLUDE_DIR = os.environ["XFT_HEADER_DIR"]
XFT_LIBRARY_DIR = os.environ["XFT_LIB_DIR"]

# include path third_party: XFT sources plus its bundled oneDNN/xDNN headers.
paddle_custom_kernel_include += [
    os.path.join(XFT_INCLUDE_DIR, "include"),  # glog
    os.path.join(XFT_INCLUDE_DIR, "src/common"),
    os.path.join(XFT_INCLUDE_DIR, "src/kernel"),
    os.path.join(XFT_INCLUDE_DIR, "src/layers"),
    os.path.join(XFT_INCLUDE_DIR, "src/models"),
    os.path.join(XFT_INCLUDE_DIR, "src/utils"),
    os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/include"),
    os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/build/include"),
    os.path.join(XFT_INCLUDE_DIR, "3rdparty/xdnn"),
]

# libs path: paddle's bundled shared objects plus the freshly built XFT ones.
paddle_custom_kernel_library_dir = [os.path.join(path, "paddle", "base") for path in site_packages_path]
paddle_custom_kernel_library_dir += [XFT_LIBRARY_DIR]

# ':' prefix tells the linker to match the exact file name (-l:libfoo.so).
libs = [":libxfastertransformer.so", ":libxft_comm_helper.so"]

custom_kernel_dot_module = CppExtension(
    sources=[
        "./src/xft_llama_layer.cc",
        "../generation/save_with_output.cc",
        "./src/token_penalty_multi_scores.cc",
        "./src/stop_generation_multi_ends.cc",
        "./src/set_value_by_flags.cc",
    ],
    include_dirs=paddle_custom_kernel_include,
    library_dirs=paddle_custom_kernel_library_dir,
    libraries=libs,
    extra_compile_args=paddle_extra_compile_args,
)

setup(
    name="paddlenlp_ops",
    version="1.0",
    # BUGFIX: typo "fot" -> "for".
    description="custom kernel for compiling",
    ext_modules=[custom_kernel_dot_module],
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <stdlib.h>
16+
#include <string.h>
17+
18+
#include "paddle/extension.h"
19+
#include <stdio.h>
20+
21+
22+
// Return true when `id` appears among the first `length` entries of
// `end_ids`.  BUGFIX(idiom): the original kept a dead `flag` local that was
// always false and only ever returned via `return flag;` — removed.
bool is_in_end(const int64_t id, const int64_t* end_ids, int length) {
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  return false;
}

// Per-batch post-processing of sampled tokens:
//  * a sequence already flagged as stopped keeps emitting end_ids[0]
//    (well-defined pad/end token for downstream consumers);
//  * any sequence whose current token is one of the end ids gets its flag
//    raised in `stop_flags_out`.
//
// Args:
//   stop_flags:     [bs]          input flags (read-only)
//   end_ids:        [end_length]  terminator ids; end_ids[0] doubles as the
//                                 pad token (assumes end_length >= 1 -- TODO confirm)
//   topk_ids:       [bs]          sampled tokens, overwritten in place for
//                                 already-stopped sequences
//   stop_flags_out: [bs]          entries are only ever set to true, never cleared
void set_value_by_flags(const bool* stop_flags,
                        const int64_t* end_ids,
                        int64_t* topk_ids,
                        bool* stop_flags_out,
                        const int bs,
                        int end_length) {
  for (int bi = 0; bi < bs; bi++) {
    if (stop_flags[bi]) {
      topk_ids[bi] = end_ids[0];
    }
    if (is_in_end(topk_ids[bi], end_ids, end_length)) {
      stop_flags_out[bi] = true;
    }
  }
}
45+
46+
47+
// Forward for set_stop_value_multi_ends: clones topk_ids/stop_flags, then
// rewrites the clones so finished sequences emit end_ids[0] and sequences
// that just produced an end token are marked stopped.
std::vector<paddle::Tensor> GetStopFlagsMulti(const paddle::Tensor& topk_ids,
                                              const paddle::Tensor& stop_flags,
                                              const paddle::Tensor& end_ids) {
  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);

  const int64_t batch = topk_ids.shape()[0];
  const int64_t num_end_ids = end_ids.shape()[0];

  // Work on copies: the op's inputs themselves stay untouched.
  auto topk_ids_out = topk_ids.copy_to(topk_ids.place(), false);
  auto stop_flags_out = stop_flags.copy_to(stop_flags.place(), false);

  set_value_by_flags(stop_flags.data<bool>(),
                     end_ids.data<int64_t>(),
                     topk_ids_out.data<int64_t>(),
                     stop_flags_out.data<bool>(),
                     batch,
                     num_end_ids);

  return {topk_ids_out, stop_flags_out};
}
67+
68+
// InferShape: the two outputs mirror the shapes of topk_ids and stop_flags.
std::vector<std::vector<int64_t>> GetStopFlagsMultiInferShape(
    const std::vector<int64_t>& topk_ids_shape,
    const std::vector<int64_t>& stop_flags_shape,
    const std::vector<int64_t>& end_ids_shape) {
  return {topk_ids_shape, stop_flags_shape};
}
74+
75+
std::vector<paddle::DataType> GetStopFlagsMultiInferDtype(
76+
const paddle::DataType& topk_ids_dtype,
77+
const paddle::DataType& stop_flags_dtype,
78+
const paddle::DataType& end_ids_dtype) {
79+
return {topk_ids_dtype, stop_flags_dtype};
80+
}
81+
82+
// Register the custom op: consumes sampled tokens, the current stop flags and
// the terminator-id list, and produces the adjusted tokens plus updated flags.
PD_BUILD_OP(set_stop_value_multi_ends)
    .Inputs({"topk_ids", "stop_flags", "end_ids"})
    .Outputs({"topk_ids_out", "stop_flags_out"})
    .SetKernelFn(PD_KERNEL(GetStopFlagsMulti))
    .SetInferShapeFn(PD_INFER_SHAPE(GetStopFlagsMultiInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetStopFlagsMultiInferDtype));

0 commit comments

Comments
 (0)