
Commit ed186fc

XPU devices: support Llama-7B basic-mode inference (with BlockAttention turned on)

1 parent 4609d07 · commit ed186fc

File tree

11 files changed: +134 −69 lines

llm/docs/inference.md

Lines changed: 12 additions & 0 deletions
@@ -83,7 +83,10 @@ PaddleNLP provides high-performance custom operators written for the Transformer series to improve …
 
 ```shell
 git clone https://github.com/PaddlePaddle/PaddleNLP
+# Install the custom operators for GPU devices
 cd ./paddlenlp/csrc && python setup_cuda.py install
+# Install the custom operators for XPU devices
+cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
 ```
 
 ### 2.3 High-performance inference with BlockAttention disabled

@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
 # Reference command for dynamic-graph model inference
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn
 
+# Reference command for dynamic-graph model inference on XPU devices
+python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu
+
 # Reference for Weight Only Int8 dynamic-graph inference
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn
 

@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
 # Reference command for dynamic-to-static export
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn
 
+# Reference command for dynamic-to-static export on XPU devices
+python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu
+
 # Reference for Weight Only Int8 dynamic-to-static export
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn
 

@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
 # Reference command for static-graph inference
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn
 
+# Reference command for static-graph inference on XPU devices
+python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu
+
 # Reference for Weight Only Int8 static-graph inference
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn

llm/predictor.py

Lines changed: 15 additions & 0 deletions
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            raise ValueError(
+                "You should export the XPU static model with the --block_attn flag and run the predictor with --block_attn as well. "
+                "See https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
+            )
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)

@@ -1076,6 +1081,16 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            config.enable_xpu()
+            device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
+            config.set_xpu_device_id(device_id)
+            xpu_config = paddle.inference.XpuConfig()
+            xpu_config.device_id = device_id
+            xpu_config.l3_size = 63 * 1024 * 1024
+            xpu_config.l3_autotune_size = 63 * 1024 * 1024
+            config.set_xpu_config(xpu_config)
+            config.enable_new_executor()
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)
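
For readers following the static-graph path, here is a minimal standalone sketch of how the same XPU inference config could be assembled outside the predictor, assuming a Paddle build with XPU inference support (the Config/XpuConfig calls are the ones used in the hunk above); the exported model and params file names are hypothetical placeholders.

```python
import os

import paddle.inference as paddle_infer

# Hypothetical static-graph artifacts produced by export_model.py.
model_file = "./inference/llama.pdmodel"
params_file = "./inference/llama.pdiparams"

config = paddle_infer.Config(model_file, params_file)

# Select the XPU card, mirroring the FLAGS_selected_xpus convention above.
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.enable_xpu()
config.set_xpu_device_id(device_id)

# Per-device XPU options: reserve ~63 MB of on-chip L3 cache and let the
# runtime autotune how that budget is split across kernels.
xpu_config = paddle_infer.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024
xpu_config.l3_autotune_size = 63 * 1024 * 1024
config.set_xpu_config(xpu_config)

# The commit also turns on the new executor for this path.
config.enable_new_executor()

predictor = paddle_infer.create_predictor(config)
```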

paddlenlp/experimental/transformers/bloom/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,6 @@
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset, get_padding_offset_v2
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedBlockMultiTransformer,

@@ -219,6 +218,7 @@ def set_input_embeddings(self, new_embeddings: Tensor):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

@@ -592,6 +592,7 @@ def set_transformer_block(self, transformer_config):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
        ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
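
The change here is the deferred-import pattern that the chatglm, chatglm_v2, and generation_utils hunks below also apply: paddlenlp_ops kernels are now imported inside the functions that call them, so merely importing the modeling module no longer requires every kernel to exist in the installed paddlenlp_ops build (for example, an XPU build of the ops). A minimal sketch of the idea, using a hypothetical kernel name rather than a real paddlenlp_ops symbol:

```python
# Before: a module-level import makes the whole module unusable on builds of
# paddlenlp_ops that lack this one kernel.
#
#     from paddlenlp_ops import some_cuda_only_kernel   # ImportError at import time
#
# After: the import is deferred into the call site, so an ImportError can only
# occur if the kernel is actually invoked on this device.


def call_kernel_lazily(*args):
    # `some_cuda_only_kernel` is a hypothetical name used for illustration.
    from paddlenlp_ops import some_cuda_only_kernel  # resolved on first call

    return some_cuda_only_kernel(*args)
```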

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@
 from paddle import nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerConfig,

@@ -273,6 +272,7 @@ def __init__(self, config: ChatGLMConfig):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

paddlenlp/experimental/transformers/chatglm_v2/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,6 @@
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,

@@ -202,6 +201,7 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 87 additions & 46 deletions
@@ -15,7 +15,7 @@
 
 import paddle
 import paddle.distributed as dist
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, in_dynamic_mode, core
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
     fused_rms_norm,

@@ -28,24 +28,25 @@
 
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
+from paddlenlp_ops import rebuild_padding_v2
 
-if is_paddlenlp_ops_available():
+
+if not is_paddlenlp_ops_available():
+    logger.warning(
+        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
+        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+    )
+
+if core.is_compiled_with_cuda():
     from paddlenlp_ops import (
         dequant_int8,
         encode_rotary_qk,
         qkv_transpose_split,
         quant_int8,
         rebuild_padding,
-        rebuild_padding_v2,
         transpose_remove_padding,
         write_cache_kv,
     )
-else:
-    logger.warning(
-        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
-        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 __all__ = [
     "FusedMultiTransformerConfig",

@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
 class FusedBlockMultiTransformer(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
+        if not core.is_compiled_with_cuda():
+            self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')
+            self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')
 
     def compute_attn(
         self,

@@ -1375,43 +1379,80 @@ def compute_attn(
             v_quant_scales = self.cache_v_scales
             k_dequant_scales = self.cache_k_out_scales
             v_dequant_scales = self.cache_v_out_scales
-
-        fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
-            qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
-            kwargs.get("seq_lens_encoder", None),
-            kwargs.get("seq_lens_decoder", None),
-            kwargs.get("seq_lens_this_time", None),
-            kwargs.get("padding_offsets", None),
-            kwargs.get("cum_offsets", None),
-            kwargs.get("cu_seqlens_q", None),
-            kwargs.get("cu_seqlens_k", None),
-            kwargs.get("block_tables", None),
-            pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
-            pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
-            k_quant_scales[i] if k_quant_scales is not None else None,
-            v_quant_scales[i] if v_quant_scales is not None else None,
-            k_dequant_scales[i] if k_dequant_scales is not None else None,
-            v_dequant_scales[i] if v_dequant_scales is not None else None,
-            None,  # qkv_out_scales
-            None,  # qkv_bias
-            None,  # out_shifts
-            None,  # out_smooths
-            kwargs.get("max_enc_len_this_time", None),
-            kwargs.get("max_dec_len_this_time", None),
-            rotary_embs,
-            attn_mask,
-            kwargs.get("tgt_mask", None),
-            kwargs.get("max_input_length", -1),
-            kwargs.get("block_size", 64),
-            self.use_neox_rotary_style,
-            self.config.use_dynamic_cachekv_quant,
-            quant_round_type=self.config.quant_round_type,
-            quant_max_bound=self.config.quant_max_bound,
-            quant_min_bound=self.config.quant_min_bound,
-        )[0]
-
+        if not core.is_compiled_with_cuda():
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                self.cache_k_per_batch_maxs,
+                self.cache_v_per_batch_maxs,
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
+        else:
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
         out_linear_out = self.compute_out_linear(fmha_out, i)
 
         return out_linear_out
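
The heart of this change is a build-time dispatch in compute_attn: on non-CUDA builds (the XPU case) the layer calls block_multihead_attention_xpu and passes the two per-batch cache-max tensors allocated in __init__, while on CUDA builds the original block_multihead_attention call is kept verbatim. A minimal sketch of that dispatch is below; pick_block_attention_kernel is a hypothetical helper used only for illustration, and the _xpu kernel is assumed to be present only in XPU-capable Paddle builds.

```python
import paddle
from paddle.framework import core


def pick_block_attention_kernel(cache_k_per_batch_maxs, cache_v_per_batch_maxs):
    """Return (kernel, extra_positional_args) for the current build.

    Mirrors the branch in FusedBlockMultiTransformer.compute_attn: the XPU
    variant takes two extra per-batch cache-max tensors right after the block
    tables; the CUDA variant keeps its original signature.
    """
    if not core.is_compiled_with_cuda():
        # XPU build: use the _xpu kernel and thread through the cache-max tensors.
        return (
            paddle.incubate.nn.functional.block_multihead_attention_xpu,
            [cache_k_per_batch_maxs, cache_v_per_batch_maxs],
        )
    # CUDA build: original kernel, no extra arguments.
    return paddle.incubate.nn.functional.block_multihead_attention, []
```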

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 9 additions & 11 deletions
@@ -17,17 +17,6 @@
 
 import paddle
 import paddle.nn.functional as F
-from paddlenlp_ops import (
-    get_token_penalty_multi_scores,
-    get_token_penalty_multi_scores_v2,
-    save_output,
-    save_with_output,
-    set_stop_value_multi_ends,
-    set_stop_value_multi_ends_v2,
-    set_value_by_flags_and_idx,
-    set_value_by_flags_and_idx_v2,
-    update_inputs,
-)
 
 from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
 

@@ -208,6 +197,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
         model_kwargs["stop_flags"] = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
         if cache is None:
             next_tokens = paddle.where(just_decoder, paddle.full_like(next_tokens, -1), next_tokens)
+        from paddlenlp_ops import set_stop_value_multi_ends
         next_tokens, model_kwargs["stop_flags"] = set_stop_value_multi_ends(
             next_tokens, model_kwargs["stop_flags"], eos_token_id, 2
         )  # multi ends

@@ -305,6 +295,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             )  # not update when continue decode
         else:
             step_idx = model_kwargs["step_idx"]
+        from paddlenlp_ops import set_value_by_flags_and_idx
         model_kwargs["stop_flags"] = set_value_by_flags_and_idx(
             model_kwargs["pre_ids"],
             model_kwargs["tgt_ids"],

@@ -316,6 +307,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
         logits = paddle.cast(logits, paddle.float32)
         logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
 
+        from paddlenlp_ops import get_token_penalty_multi_scores
         logits = get_token_penalty_multi_scores(
             model_kwargs["pre_ids"],
             logits,

@@ -347,6 +339,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
         else:
             model_kwargs["all_input_ids"] = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1)
 
+        from paddlenlp_ops import save_with_output
         save_with_output(
             next_tokens,
             batch_idx,

@@ -635,6 +628,7 @@ def _post_process_(
             model_kwargs,
         ):
         step_idx = model_kwargs["step_idx"]
+        from paddlenlp_ops import set_value_by_flags_and_idx_v2
         set_value_by_flags_and_idx_v2(
             model_kwargs["pre_ids"],
             model_kwargs["input_ids"],

@@ -648,6 +642,7 @@ def _post_process_(
         logits = paddle.cast(outputs, paddle.float32)
 
         # pre-process distribution
+        from paddlenlp_ops import get_token_penalty_multi_scores_v2
         logits = get_token_penalty_multi_scores_v2(
             model_kwargs["pre_ids"],
             logits,

@@ -673,11 +668,13 @@ def _post_process_(
         paddle.assign(step_idx, model_kwargs["step_idx"])
         length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
         stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
+        from paddlenlp_ops import set_stop_value_multi_ends_v2
         set_stop_value_multi_ends_v2(
             next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
         )  # multi ends
         paddle.assign(stop_flags, model_kwargs["stop_flags"])
         # update inputs
+        from paddlenlp_ops import update_inputs
         update_inputs(
             stop_flags,
             model_kwargs["not_need_stop"],

@@ -689,6 +686,7 @@ def _post_process_(
             next_tokens,
             model_kwargs["is_block_step"],
         )
+        from paddlenlp_ops import save_output
         save_output(next_tokens, model_kwargs["not_need_stop"], self.config.tensor_parallel_rank)
         return next_tokens
 