
Commit d96f9bb

Author: 395822456@qq.com
Commit message: update docs and shell
1 parent f5c1503 commit d96f9bb

File tree: 11 files changed (+715, -35 lines)


csrc/cpu/0001-patch-fp16-and-bf16.patch

Lines changed: 280 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/0001-patch-fp32.patch

Lines changed: 302 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/README.md

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# cpu-custom-ops
+
+## Quick start
+# Build the CPU custom-op library
+```
+$ Prerequisite: the machine must support AVX instructions
+$ cd src
+$ bash setup.sh
+```
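Before running setup.sh it can help to confirm the AVX prerequisite up front. Below is a minimal Python sketch, assuming a Linux host where /proc/cpuinfo lists the CPU flags; setup.sh itself only branches on avx512_bf16, so treating these two flags as the minimum requirement is an assumption, not something this commit states.

```
# Check the CPU flags the build cares about (avx512_bf16 selects the bf16/fp16 patch in setup.sh).
def cpu_flags():
    with open("/proc/cpuinfo") as f:  # Linux-only assumption
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":", 1)[1].split())
    return set()

flags = cpu_flags()
print("avx:", "avx" in flags)                  # the README's stated prerequisite
print("avx512_bf16:", "avx512_bf16" in flags)  # enables the fp16/bf16 patch path
```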

csrc/cpu/setup.sh

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# 1. download XFT
+if [ ! -d xFasterTransformer ]; then
+    git clone --branch v1.7.2 https://github.com/intel/xFasterTransformer.git
+fi
+
+# 2. copy patch
+cd xFasterTransformer
+git checkout .
+cd ..
+
+if lscpu | grep -q "avx512_bf16"; then
+    echo "apply bf16 and fp16."
+    if [ ! -f 0001-patch-fp16-and-bf16.patch ]; then
+        echo "Error: 0001-patch-fp16-and-bf16.patch does not exist."
+        exit 1
+    fi
+    # apply patch
+    cp ./0001-patch-fp16-and-bf16.patch ./xFasterTransformer/paddle.patch
+else
+    echo "apply fp32"
+    if [ ! -f 0001-patch-fp32.patch ]; then
+        echo "Error: 0001-patch-fp32.patch does not exist."
+        exit 1
+    fi
+    cp ./0001-patch-fp32.patch ./xFasterTransformer/paddle.patch
+fi
+
+# 3. apply patch
+cd xFasterTransformer
+git apply paddle.patch
+
+# 4. build xFasterTransformer
+sh ./3rdparty/prepare_oneccl.sh
+source ./3rdparty/oneccl/build/_install/env/setvars.sh
+source /workspace/cpu_repo/xFasterTransformer/3rdparty/oneccl/build/_install/env/setvars.sh
+
+rm -rf build
+mkdir build && cd build
+cmake ..
+make -j
+
+# xft
+export XFT_HEADER_DIR=$PWD
+export XFT_LIB_DIR=$XFT_HEADER_DIR/build
+export LD_LIBRARY_PATH=$XFT_LIB_DIR:$LD_LIBRARY_PATH
+
+# setup cpu paddle_nlp ops
+cd ..
+python ./src/setup_cpu.py install
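setup_cpu.py reads the XFT paths from the environment, so a quick pre-flight before the final install step can catch a missing export. This is a sketch only: the library file name libxfastertransformer.so is an assumption about the xFasterTransformer build output, not something this commit states.

```
# Pre-flight check before "python ./src/setup_cpu.py install".
import os

for var in ("XFT_HEADER_DIR", "XFT_LIB_DIR"):
    value = os.environ.get(var)
    print(f"{var} = {value}")
    if not value or not os.path.isdir(value):
        raise SystemExit(f"{var} is unset or not a directory; re-run the exports from setup.sh")

lib = os.path.join(os.environ["XFT_LIB_DIR"], "libxfastertransformer.so")  # assumed file name
print("xFT library found" if os.path.exists(lib) else f"warning: {lib} not found")
```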

csrc/cpu/src/set_value_by_flags.cc

Lines changed: 1 addition & 2 deletions
@@ -27,10 +27,9 @@ void set_value_by_flag_and_id(const bool *stop_flags, int64_t *pre_ids_all, cons
 
 std::vector<paddle::Tensor> SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, const paddle::Tensor& pre_ids_now, const paddle::Tensor& step_idx, const paddle::Tensor& stop_flags) {
     std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
-    auto stop_flags_out = stop_flags.copy_to(stop_flags.place(), false); // gpu -> gpu
+    auto stop_flags_out = stop_flags.copy_to(stop_flags.place(), false);
 
     int bs = stop_flags.shape()[0];
-    // length of max_len
     int length = pre_ids_all_shape[1];
 
     set_value_by_flag_and_id(stop_flags.data<bool>(), const_cast<int64_t*>(pre_ids_all.data<int64_t>()), pre_ids_now.data<int64_t>(), step_idx.data<int64_t>(), bs, length);

csrc/cpu/src/setup.py renamed to csrc/cpu/src/setup_cpu.py

Lines changed: 36 additions & 14 deletions
@@ -14,6 +14,7 @@
 
 import os
 import site
+import subprocess
 
 from paddle.utils.cpp_extension import CppExtension, setup
 
@@ -32,21 +33,45 @@ def build_extensions(self):
         super().build_extensions()
 
 
+def check_avx512_bf16_support():
+    try:
+        result = subprocess.run(
+            "lscpu | grep avx512_bf16",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            shell=True,
+        )
+
+        if "avx512_bf16" in result.stdout.lower():
+            return True
+        else:
+            return False
+
+    except Exception as e:
+        print(f"Error checking AVX512 support: {e}")
+        return False
+
+
 # cc flags
 paddle_extra_compile_args = [
     "-std=c++17",
     "-shared",
     "-fPIC",
     "-Wno-parentheses",
     "-DPADDLE_WITH_CUSTOM_KERNEL",
-    "-DAVX512_FP32_WEIGHT_ONLY_FP16=true",
-    "-DAVX512_FP32_WEIGHT_ONLY_INT8=true",
-    # bf16 machine
-    # "-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
-    # "-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
 ]
 
-
+if check_avx512_bf16_support():
+    paddle_extra_compile_args += [
+        "-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
+    ]
+else:
+    paddle_extra_compile_args += [
+        "-DAVX512_FP32_WEIGHT_ONLY_FP16=true",
+        "-DAVX512_FP32_WEIGHT_ONLY_INT8=true",
+    ]
 # include path
 site_packages_path = site.getsitepackages()
 paddle_custom_kernel_include = [os.path.join(path, "paddle", "include") for path in site_packages_path]
@@ -55,10 +80,7 @@ def build_extensions(self):
 XFT_LIBRARY_DIR = os.environ["XFT_LIB_DIR"]
 
 # include path third_party
-compile_third_party_path = os.path.join(os.environ["PADDLE_BINARY_DIR"], "third_party")
 paddle_custom_kernel_include += [
-    os.path.join(compile_third_party_path, "install/gflags/include"),  # gflags
-    os.path.join(compile_third_party_path, "install/glog/include"),  # glog
     os.path.join(XFT_INCLUDE_DIR, "include"),  # glog
     os.path.join(XFT_INCLUDE_DIR, "src/common"),  # src
     os.path.join(XFT_INCLUDE_DIR, "src/kernel"),  # src
@@ -79,11 +101,11 @@ def build_extensions(self):
 
 custom_kernel_dot_module = CppExtension(
     sources=[
-        "xft_llama_layer.cc",
-        "../../generation/save_with_output.cc",
-        "token_penalty_multi_scores.cc",
-        "stop_generation_multi_ends.cc",
-        "set_value_by_flags.cc",
+        "./src/xft_llama_layer.cc",
+        "../generation/save_with_output.cc",
+        "./src/token_penalty_multi_scores.cc",
+        "./src/stop_generation_multi_ends.cc",
+        "./src/set_value_by_flags.cc",
     ],
     include_dirs=paddle_custom_kernel_include,
     library_dirs=paddle_custom_kernel_library_dir,

csrc/cpu/src/xft_llama_layer.cc

Lines changed: 11 additions & 7 deletions
@@ -44,9 +44,9 @@ std::vector<paddle::Tensor> InvokeLLaMALayer(
     auto out = paddle::empty_like(input);
     auto batchSize = input.shape()[0];
     auto inputSeqLen = input.shape()[1];
-    auto past_seq_len=pastSeqLen.data<int64_t>()[0];
-    auto cur_seq_len=currentSeqLen.data<int64_t>()[0];
-    auto step_id=step.data<int64_t>()[0];
+    auto past_seq_len = pastSeqLen.data<int64_t>()[0];
+    auto cur_seq_len = currentSeqLen.data<int64_t>()[0];
+    auto step_id = step.data<int64_t>()[0];
     auto input_ptr = reinterpret_cast<const void *>(input.data<float>());
     auto ln1Gamma_ptr = reinterpret_cast<const float *>(ln1Gamma.data<float>());
     auto qkvWeight_ptr = reinterpret_cast<const void *>(qkvWeight.data<float>());
@@ -64,12 +64,16 @@ std::vector<paddle::Tensor> InvokeLLaMALayer(
         xft_data_type = xft::DataType::bf16;
     }
     auto xft_act_type = xft::ActivationType::SILU;
-    if (activation == "silu") {
-        xft_act_type = xft::ActivationType::SILU;
+    if (activation == "relu") {
+        xft_act_type = xft::ActivationType::RELU;
+    } else if (activation == "gelu") {
+        xft_act_type = xft::ActivationType::GELU;
+    } else if (activation == "swiglu") {
+        xft_act_type = xft::ActivationType::SWIGLU;
     }
     auto xft_norm_type = xft::NormType::RMS;
-    if (normType == "rmsnorm") {
-        xft_norm_type = xft::NormType::RMS;
+    if (normType == "layernorm") {
+        xft_norm_type = xft::NormType::LN;
     }
     invokeLayerLLaMA(xft_data_type,
                      xft_act_type,

llm/docs/inference.md

Lines changed: 9 additions & 0 deletions
@@ -97,6 +97,9 @@ cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
 # Reference command for dynamic-graph model inference
 python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16
 
+# Reference for dynamic-graph inference on a CPU device with AVX instructions
+python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float32 --avx_mode --avx_type "fp16" --device "cpu"
+
 # Reference for dynamic-graph inference with PrefixTuning
 python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --export_precache true --prefix_path ./checkpoints/llama_prefix_ckpts
 
@@ -117,6 +120,9 @@ python ./predict/predictor.py --model_name_or_path checkpoints/llama_ptq_ckpts -
 # Reference command for dynamic-to-static export
 python ./predict/export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16
 
+# Reference for dynamic-to-static export on CPU with AVX instructions
+python ./predict/export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --avx_mode --avx_type "fp16" --device "cpu"
+
 # Reference for dynamic-to-static export with PrefixTuning
 python ./predict/export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --export_precache true
 
@@ -137,6 +143,9 @@ python ./predict/export_model.py --model_name_or_path checkpoints/llama_ptq_ckpt
 # Reference command for static-graph inference
 python ./predict/predictor.py --model_name_or_path ./inference --inference_model --quant_type weight_only_int8 --dtype "float16" --mode "static"
 
+# Reference for static-graph inference on CPU with AVX instructions
+python ./predict/predictor.py --model_name_or_path ./inference --inference_model --avx_mode --avx_type "fp16" --dtype "float32" --mode "static" --device "cpu"
+
 # Reference for static-graph inference with PrefixTuning
 python ./predict/predictor.py --model_name_or_path ./inference --inference_model --quant_type weight_only_int8 --dtype "float16" --mode "static" --export_precache true --prefix_path ./checkpoints/llama_prefix_ckpts
 

llm/predict/predictor.py

Lines changed: 3 additions & 0 deletions
@@ -401,6 +401,9 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
         self.arange_tensor_encoder = paddle.arange(config.total_max_length, dtype=self.dtype)
 
         if config.device == "cpu" and config.avx_model:
+            assert (
+                "llama" in self.architectures and self.model_config.model_type != "llama-img2txt"
+            ), "avx_mode only supports llama for now"
            self.cache_kvs = None
            self.attention_mask = None
            self.tgt_generation_mask = None

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 1 addition & 5 deletions
@@ -1093,11 +1093,7 @@ def init_weight_shape(self, config):
         self.gate_weight_shape = [self.embed_dim, self.dim_feedforward]
         self.up_weight_shape = [self.embed_dim, self.dim_feedforward]
         self.down_weight_shape = [self.dim_feedforward, self.embed_dim]
-        self.qkv_weight_shape = (
-            [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
-            if config.trans_qkvw
-            else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
-        )
+        self.qkv_weight_shape = [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
         self.ffn1_weight_shape = (
             [self.embed_dim, self.dim_feedforward * 2]
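For a concrete sense of the fused QKV shape computed above, here is a small worked example with hypothetical Llama-7B-style dimensions (embed_dim=4096, num_heads=kv_num_heads=32, head_dim=128 are illustrative values, not taken from this diff).

```
# Illustrative arithmetic for qkv_weight_shape with hypothetical dimensions.
embed_dim, num_heads, kv_num_heads, head_dim = 4096, 32, 32, 128
qkv_weight_shape = [embed_dim, (num_heads + 2 * kv_num_heads) * head_dim]
print(qkv_weight_shape)  # [4096, 12288]: Q, K and V projections fused along the output dim
```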

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 0 additions & 7 deletions
@@ -161,10 +161,6 @@ def __init__(self, config: LlamaConfig):
     def set_transformer_block(self, transformer_config, max_position_embeddings, compute_type):
         self.transformer_block = FusedMultiTransformerAvx(transformer_config, max_position_embeddings, compute_type)
 
-    def remove_padding(self, input_ids, seq_lens_this_time):
-        pass
-
-    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
     @staticmethod
     def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
         batch_size = 1
@@ -193,7 +189,6 @@ def forward(
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         # genereate a fake input_ids according to inputs_embeds
-        # this is usually occurred in img2txt multimodal model when first enter into this forward function.
         if input_ids is None and inputs_embeds is not None:
             input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
         if inputs_embeds is not None:
@@ -295,13 +290,11 @@ def set_state_dict(self, state_dict):
             concated_ffn1_weight = np.concatenate(
                 [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
             )
-            # ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)
             gate_up_list = split_fn(concated_ffn1_weight)
             gate_weight_tensor = paddle.to_tensor(gate_up_list[0])
             up_weight_tensor = paddle.to_tensor(gate_up_list[1])
 
             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
-            # no need to process quantized weights offline
             self.transformer_block.qkv_weights[idx].set_value(
                 qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
             )
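The set_state_dict path above concatenates the gate and up projections into one ffn1 weight and then splits them back with split_fn before assigning each tensor. Below is a tiny NumPy sketch of that round trip; the shapes and the even two-way split are illustrative assumptions, not the actual split_fn.

```
import numpy as np

# Stand-ins for mlp.gate_proj.weight and mlp.up_proj.weight (hypothetical small shapes).
gate = np.zeros((8, 16))
up = np.ones((8, 16))

concated_ffn1_weight = np.concatenate([gate, up], axis=-1)       # shape (8, 32)
gate_back, up_back = np.split(concated_ffn1_weight, 2, axis=-1)  # stand-in for split_fn

assert (gate_back == gate).all() and (up_back == up).all()
```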
