@@ -15,7 +15,7 @@
 
 import paddle
 import paddle.distributed as dist
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, core, in_dynamic_mode
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
     fused_rms_norm,
@@ -29,23 +29,24 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-if is_paddlenlp_ops_available():
+if not is_paddlenlp_ops_available():
+    logger.warning(
+        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
+        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+    )
+
+from paddlenlp_ops import rebuild_padding_v2
+
+if core.is_compiled_with_cuda():
     from paddlenlp_ops import (
         dequant_int8,
         encode_rotary_qk,
         qkv_transpose_split,
         quant_int8,
         rebuild_padding,
-        rebuild_padding_v2,
         transpose_remove_padding,
         write_cache_kv,
     )
-else:
-    logger.warning(
-        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
-        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 __all__ = [
     "FusedMultiTransformerConfig",
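
The hunk above turns paddlenlp_ops into a soft dependency: a missing package now only triggers a warning, rebuild_padding_v2 is imported for every backend, and the remaining kernels are gated behind core.is_compiled_with_cuda(). The snippet below is a minimal, self-contained sketch of that pattern, not the actual PaddleNLP module: it uses importlib.util.find_spec in place of PaddleNLP's is_paddlenlp_ops_available helper, and (unlike the diff) it only attempts the imports when the ops package is present, assuming Paddle is installed alongside it.

import importlib.util
import logging

logger = logging.getLogger(__name__)


def ops_available(pkg: str = "paddlenlp_ops") -> bool:
    # find_spec returns None when the package cannot be located on sys.path.
    return importlib.util.find_spec(pkg) is not None


if not ops_available():
    logger.warning(
        "paddlenlp_ops is not installed; see "
        "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
    )
else:
    # rebuild_padding_v2 is needed by every backend, so it sits outside the
    # CUDA check; the remaining kernels in this file are CUDA-only.
    from paddlenlp_ops import rebuild_padding_v2  # noqa: F401

    import paddle

    if paddle.is_compiled_with_cuda():
        from paddlenlp_ops import write_cache_kv  # noqa: F401
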
@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
 class FusedBlockMultiTransformer(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
+        if not core.is_compiled_with_cuda():
+            self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
+            self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
 
     def compute_attn(
         self,
@@ -1375,43 +1379,80 @@ def compute_attn
         v_quant_scales = self.cache_v_scales
         k_dequant_scales = self.cache_k_out_scales
         v_dequant_scales = self.cache_v_out_scales
-
-        fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
-            qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
-            kwargs.get("seq_lens_encoder", None),
-            kwargs.get("seq_lens_decoder", None),
-            kwargs.get("seq_lens_this_time", None),
-            kwargs.get("padding_offsets", None),
-            kwargs.get("cum_offsets", None),
-            kwargs.get("cu_seqlens_q", None),
-            kwargs.get("cu_seqlens_k", None),
-            kwargs.get("block_tables", None),
-            pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
-            pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
-            k_quant_scales[i] if k_quant_scales is not None else None,
-            v_quant_scales[i] if v_quant_scales is not None else None,
-            k_dequant_scales[i] if k_dequant_scales is not None else None,
-            v_dequant_scales[i] if v_dequant_scales is not None else None,
-            None,  # qkv_out_scales
-            None,  # qkv_bias
-            None,  # out_shifts
-            None,  # out_smooths
-            kwargs.get("max_enc_len_this_time", None),
-            kwargs.get("max_dec_len_this_time", None),
-            rotary_embs,
-            attn_mask,
-            kwargs.get("tgt_mask", None),
-            kwargs.get("max_input_length", -1),
-            kwargs.get("block_size", 64),
-            self.use_neox_rotary_style,
-            self.config.use_dynamic_cachekv_quant,
-            quant_round_type=self.config.quant_round_type,
-            quant_max_bound=self.config.quant_max_bound,
-            quant_min_bound=self.config.quant_min_bound,
-        )[0]
-
+        if not core.is_compiled_with_cuda():
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                self.cache_k_per_batch_maxs,
+                self.cache_v_per_batch_maxs,
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
+        else:
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
         out_linear_out = self.compute_out_linear(fmha_out, i)
 
         return out_linear_out
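
Distilled from the two call sites above, the dispatch rule is: non-CUDA (e.g. XPU) builds route through block_multihead_attention_xpu and pass the cache_k_per_batch_maxs / cache_v_per_batch_maxs buffers created in __init__ right after block_tables; CUDA builds keep the original block_multihead_attention call, and every other argument is forwarded identically. Below is a small sketch of that selection logic; the helper name is hypothetical, and it assumes a Paddle build that actually ships the _xpu kernel, as this diff does.

import paddle
from paddle.framework import core


def select_block_attention_kernel():
    """Return (kernel, needs_per_batch_maxs) for the current Paddle build."""
    if core.is_compiled_with_cuda():
        # CUDA builds: the original kernel, no extra cache-max tensors.
        return paddle.incubate.nn.functional.block_multihead_attention, False
    # Non-CUDA (e.g. XPU) builds: the _xpu variant additionally expects the
    # per-batch cache max buffers right after block_tables.
    return paddle.incubate.nn.functional.block_multihead_attention_xpu, True

Keeping both argument lists spelled out in the source, rather than building them dynamically, makes the signature difference between the two kernels easy to audit.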