
Commit 3d777c1

[XPU] xpu devices support llama-7b basic mode inference (turn on BlockAttention) (#8588)
* xpu devices support llama-7b basic mode inference (turn on BlockAttention)
1 parent 5ba7a94 commit 3d777c1

11 files changed: +165 −71 lines changed


llm/docs/inference.md

Lines changed: 12 additions & 0 deletions
@@ -83,7 +83,10 @@ PaddleNLP provides high-performance custom operators for the Transformer series to improve
 
 ```shell
 git clone https://github.com/PaddlePaddle/PaddleNLP
+# Install the custom operators for GPU devices
 cd ./paddlenlp/csrc && python setup_cuda.py install
+# Install the custom operators for XPU devices
+cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
 ```
 
 ### 2.3 High-performance inference with BlockAttention disabled
@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
 # Reference command for dynamic graph inference
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn
 
+# Reference command for dynamic graph inference on XPU devices
+python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu
+
 # Reference command for Weight Only Int8 dynamic graph inference
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn
 
@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
 # Reference command for dynamic-to-static export
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn
 
+# Reference command for dynamic-to-static export on XPU devices
+python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu
+
 # Reference command for Weight Only Int8 dynamic-to-static export
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn
 
@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
 # Reference command for static graph inference
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn
 
+# Reference command for static graph inference on XPU devices
+python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu
+
 # Reference command for Weight Only Int8 static graph inference
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn

llm/predictor.py

Lines changed: 25 additions & 2 deletions
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            raise ValueError(
+                "you should export xpu static model with --block_attn flag and use predictor with --block_attn too, see "
+                "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
+            )
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)
@@ -920,7 +925,9 @@ def _preprocess(self, source):
             source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]
 
         for i, text in enumerate(source):
-            add_special_tokens = self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer))
+            add_special_tokens = self.tokenizer.chat_template is None or isinstance(
+                self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)
+            )
             add_special_tokens = add_special_tokens if not self.benchmark else False
             tokens = self.tokenizer(
                 text,
@@ -1076,6 +1083,15 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            config.enable_xpu()
+            device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
+            config.set_xpu_device_id(device_id)
+            xpu_config = paddle.inference.XpuConfig()
+            xpu_config.device_id = device_id
+            xpu_config.l3_size = 63 * 1024 * 1024
+            xpu_config.l3_autotune_size = 63 * 1024 * 1024
+            config.set_xpu_config(xpu_config)
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)
@@ -1331,6 +1347,11 @@ def create_predictor(
                 tensor_parallel_rank=tensor_parallel_rank,
             )
         else:
+            if predictor_args.device == "xpu":
+                raise ValueError(
+                    "you should run xpu dynamic model with --block_attn flag, see "
+                    "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
+                )
             from paddlenlp.experimental.transformers import (
                 LlamaForCausalLMInferenceModel as LlamaInferenceModel,
             )
@@ -1588,7 +1609,9 @@ def predict():
 
 def benchmark(predictor, predictor_args, model_args):
     # Just construct a simple benchmark input. We pad input to the src_length.
-    benchmark_texts = [predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)]
+    benchmark_texts = [
+        predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)
+    ]
 
     batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
     print("***********Start Benchmark**********")

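For readers wiring this up outside of `predictor.py`, the XPU branch added above boils down to the standalone Paddle Inference setup below. This is a minimal sketch, not part of the commit: the model/params paths are placeholders for an exported static graph, it assumes a Paddle build with XPU support that exposes `paddle.inference.XpuConfig`, and the 63 MB L3 values simply mirror the ones hard-coded in the diff.

```python
import os

import paddle.inference as paddle_infer

# Sketch of the XPU inference-config setup used in _create_predictor above.
# "inference.pdmodel"/"inference.pdiparams" are placeholder paths for a model
# exported with export_model.py (see llm/docs/inference.md).
config = paddle_infer.Config("inference.pdmodel", "inference.pdiparams")

config.enable_xpu()
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))  # same env var the predictor reads
config.set_xpu_device_id(device_id)

xpu_config = paddle_infer.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024           # 63 MB of on-chip L3 cache for the fused kernels
xpu_config.l3_autotune_size = 63 * 1024 * 1024  # let the runtime autotune within the same budget
config.set_xpu_config(xpu_config)

predictor = paddle_infer.create_predictor(config)
```

Note that this static-graph path is only taken with `--block_attn`; as the first hunk shows, an XPU run without BlockAttention raises a ValueError pointing back to the docs.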
paddlenlp/experimental/transformers/bloom/modeling.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,6 @@
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset, get_padding_offset_v2
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedBlockMultiTransformer,
@@ -219,6 +218,8 @@ def set_input_embeddings(self, new_embeddings: Tensor):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
+
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
@@ -592,6 +593,8 @@ def set_transformer_block(self, transformer_config):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
+
         ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,6 @@
 from paddle import nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerConfig,
@@ -273,6 +272,8 @@ def __init__(self, config: ChatGLMConfig):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
+
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

paddlenlp/experimental/transformers/chatglm_v2/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,6 @@
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,
@@ -202,6 +201,8 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
+
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

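The three modeling files above all make the same change: the module-level `paddlenlp_ops` imports move into `remove_padding`, so the modeling modules can be imported on devices where the CUDA-only padding operators are not built. A rough standalone sketch of the pattern (written as a free function for illustration; the real code lives on the model classes):

```python
import paddle


def remove_padding(input_ids, seq_lens_this_time):
    """Flatten a padded batch into a contiguous token stream (sketch of the pattern above)."""
    cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
    token_num = paddle.sum(seq_lens_this_time)

    # Deferred import: paddlenlp_ops is only resolved when this code path actually runs,
    # so merely importing the modeling module no longer requires the custom op package.
    from paddlenlp_ops import get_padding_offset

    ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
        input_ids, cum_offsets_now, token_num, seq_lens_this_time
    )
    return ids_remove_padding, padding_offset, cum_offsets
```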
paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 87 additions & 46 deletions
@@ -15,7 +15,7 @@
 
 import paddle
 import paddle.distributed as dist
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, core, in_dynamic_mode
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
     fused_rms_norm,
@@ -29,23 +29,24 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-if is_paddlenlp_ops_available():
+if not is_paddlenlp_ops_available():
+    logger.warning(
+        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
+        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+    )
+
+from paddlenlp_ops import rebuild_padding_v2
+
+if core.is_compiled_with_cuda():
     from paddlenlp_ops import (
         dequant_int8,
         encode_rotary_qk,
         qkv_transpose_split,
         quant_int8,
         rebuild_padding,
-        rebuild_padding_v2,
         transpose_remove_padding,
         write_cache_kv,
     )
-else:
-    logger.warning(
-        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
-        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 __all__ = [
     "FusedMultiTransformerConfig",
@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
 class FusedBlockMultiTransformer(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
+        if not core.is_compiled_with_cuda():
+            self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
+            self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
 
     def compute_attn(
         self,
@@ -1375,43 +1379,80 @@ def compute_attn(
         v_quant_scales = self.cache_v_scales
         k_dequant_scales = self.cache_k_out_scales
         v_dequant_scales = self.cache_v_out_scales
-
-        fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
-            qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
-            kwargs.get("seq_lens_encoder", None),
-            kwargs.get("seq_lens_decoder", None),
-            kwargs.get("seq_lens_this_time", None),
-            kwargs.get("padding_offsets", None),
-            kwargs.get("cum_offsets", None),
-            kwargs.get("cu_seqlens_q", None),
-            kwargs.get("cu_seqlens_k", None),
-            kwargs.get("block_tables", None),
-            pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
-            pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
-            k_quant_scales[i] if k_quant_scales is not None else None,
-            v_quant_scales[i] if v_quant_scales is not None else None,
-            k_dequant_scales[i] if k_dequant_scales is not None else None,
-            v_dequant_scales[i] if v_dequant_scales is not None else None,
-            None,  # qkv_out_scales
-            None,  # qkv_bias
-            None,  # out_shifts
-            None,  # out_smooths
-            kwargs.get("max_enc_len_this_time", None),
-            kwargs.get("max_dec_len_this_time", None),
-            rotary_embs,
-            attn_mask,
-            kwargs.get("tgt_mask", None),
-            kwargs.get("max_input_length", -1),
-            kwargs.get("block_size", 64),
-            self.use_neox_rotary_style,
-            self.config.use_dynamic_cachekv_quant,
-            quant_round_type=self.config.quant_round_type,
-            quant_max_bound=self.config.quant_max_bound,
-            quant_min_bound=self.config.quant_min_bound,
-        )[0]
-
+        if not core.is_compiled_with_cuda():
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                self.cache_k_per_batch_maxs,
+                self.cache_v_per_batch_maxs,
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
+        else:
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
         out_linear_out = self.compute_out_linear(fmha_out, i)
 
         return out_linear_out

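Taken together, the fused_transformer_layers.py changes move kernel selection from import time to runtime: `rebuild_padding_v2` is imported unconditionally, the CUDA-only ops are imported only on CUDA builds, and `compute_attn` dispatches between the two block-attention kernels. A condensed sketch of that dispatch, assuming a Paddle build that exposes `block_multihead_attention_xpu` (the full ~35-argument calls are in the diff above):

```python
import paddle
from paddle.framework import core  # same import the commit adds

# On non-CUDA builds (e.g. XPU) the *_xpu kernel is used and two extra per-batch
# cache-max tensors are created once in FusedBlockMultiTransformer.__init__;
# on CUDA builds the original kernel and argument list are unchanged.
if not core.is_compiled_with_cuda():
    attention_kernel = paddle.incubate.nn.functional.block_multihead_attention_xpu
    cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
    cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
else:
    attention_kernel = paddle.incubate.nn.functional.block_multihead_attention
```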