|
| 1 | +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import os |
| 16 | + |
| 17 | +import paddle |
| 18 | +import paddle.nn.functional as F |
| 19 | + |
| 20 | +try: |
| 21 | + from paddle.incubate.nn.functional import fused_rotary_position_embedding |
| 22 | +except ImportError: |
| 23 | + fused_rotary_position_embedding = None |
| 24 | + |
| 25 | +try: |
| 26 | + from paddle.incubate.nn.functional import swiglu |
| 27 | +except ImportError: |
| 28 | + |
| 29 | + def swiglu(x, y=None): |
| 30 | + if y is None: |
| 31 | + x, y = paddle.chunk(x, chunks=2, axis=-1) |
| 32 | + return F.silu(x) * y |
| 33 | + |
| 34 | + |
| 35 | +from paddle.utils import try_import |
| 36 | + |
| 37 | +from paddlenlp.utils.tools import get_env_device |
| 38 | + |
| 39 | +try: |
| 40 | + from paddle.incubate.nn.functional import fused_rotary_position_embedding |
| 41 | +except ImportError: |
| 42 | + fused_rotary_position_embedding = None |
| 43 | +try: |
| 44 | + if get_env_device() == "npu": |
| 45 | + from paddle.base import core |
| 46 | + |
| 47 | + for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): |
| 48 | + if lib.endswith(".so"): |
| 49 | + paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib) |
| 50 | + from paddle.nn.functional.flash_attention import flash_attention |
| 51 | +except: |
| 52 | + flash_attention = None |
| 53 | + |
| 54 | + |
| 55 | +def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb): |
| 56 | + assert past_key_value is None, "fuse rotary not support cache kv for now" |
| 57 | + batch_size, seq_length, num_heads, head_dim = query_states.shape |
| 58 | + _, kv_seq_len, num_key_value_heads, _ = key_states.shape |
| 59 | + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) |
| 60 | + if get_env_device() == "npu": |
| 61 | + query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0] |
| 62 | + key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0] |
| 63 | + else: |
| 64 | + # paddle version > 2.6 or develop support q and k/v with different num_heads |
| 65 | + paddle_version = float(paddle.__version__[:3]) |
| 66 | + if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads): |
| 67 | + query_states, _, _ = fused_rotary_position_embedding( |
| 68 | + query_states, |
| 69 | + None, |
| 70 | + None, |
| 71 | + sin=sin, |
| 72 | + cos=cos, |
| 73 | + position_ids=position_ids, |
| 74 | + use_neox_rotary_style=False, |
| 75 | + ) |
| 76 | + key_states, _, _ = fused_rotary_position_embedding( |
| 77 | + key_states, |
| 78 | + None, |
| 79 | + None, |
| 80 | + sin=sin, |
| 81 | + cos=cos, |
| 82 | + position_ids=position_ids, |
| 83 | + use_neox_rotary_style=False, |
| 84 | + ) |
| 85 | + else: |
| 86 | + query_states, key_states, _ = fused_rotary_position_embedding( |
| 87 | + query_states, |
| 88 | + key_states, |
| 89 | + v=None, |
| 90 | + sin=sin, |
| 91 | + cos=cos, |
| 92 | + position_ids=position_ids, |
| 93 | + use_neox_rotary_style=False, |
| 94 | + ) |
| 95 | + return query_states, key_states |
| 96 | + |
| 97 | + |
| 98 | +def rms_norm_fused(x_in, w, eps): |
| 99 | + fused_ln = try_import("fused_ln") |
| 100 | + return fused_ln.fused_rms_norm(x_in, w, eps)[0] |
| 101 | + |
| 102 | + |
| 103 | +def fusion_rms_norm(hidden_states, weight, variance_epsilon): |
| 104 | + if get_env_device() == "npu": |
| 105 | + return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] |
| 106 | + elif get_env_device() == "xpu": |
| 107 | + try: |
| 108 | + import paddle_xpu_nn # noqa: F821 |
| 109 | + |
| 110 | + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] |
| 111 | + except ImportError: |
| 112 | + raise NotImplementedError( |
| 113 | + f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature" |
| 114 | + ) |
| 115 | + return rms_norm_fused(hidden_states, weight, variance_epsilon) |
| 116 | + |
| 117 | + |
| 118 | +def fusion_flash_attention( |
| 119 | + query_states, |
| 120 | + config, |
| 121 | + key_states, |
| 122 | + value_states, |
| 123 | + attention_mask, |
| 124 | + output_attentions, |
| 125 | + alibi=None, |
| 126 | + sequence_parallel=False, |
| 127 | + reshard_layer=None, |
| 128 | + npu_is_casual=False, |
| 129 | +): |
| 130 | + bsz, q_len, num_heads, head_dim = query_states.shape |
| 131 | + _, kv_seq_len, _, _ = value_states.shape |
| 132 | + version = paddle.version.full_version |
| 133 | + if version != "0.0.0" and version <= "2.5.2": |
| 134 | + if alibi is not None: |
| 135 | + raise ValueError("Flash Attention doesn't support alibi") |
| 136 | + attn_output, attn_weights = flash_attention( |
| 137 | + query_states, |
| 138 | + key_states, |
| 139 | + value_states, |
| 140 | + causal=True, |
| 141 | + return_softmax=output_attentions, |
| 142 | + ) |
| 143 | + else: |
| 144 | + if alibi is not None: |
| 145 | + alibi = alibi.reshape([bsz, num_heads, 1, -1]) |
| 146 | + attention_mask = attention_mask.cast(alibi.dtype) + alibi |
| 147 | + if get_env_device() == "npu": |
| 148 | + attn_output = core.eager._run_custom_op( |
| 149 | + "flash_attention_npu", |
| 150 | + query_states, |
| 151 | + key_states, |
| 152 | + value_states, |
| 153 | + None, |
| 154 | + attention_mask, |
| 155 | + 0.0, |
| 156 | + attention_mask is None, |
| 157 | + True, |
| 158 | + False, |
| 159 | + npu_is_casual, |
| 160 | + )[0] |
| 161 | + else: |
| 162 | + attn_output = F.scaled_dot_product_attention( |
| 163 | + query_states, |
| 164 | + key_states, |
| 165 | + value_states, |
| 166 | + attn_mask=attention_mask, |
| 167 | + is_causal=attention_mask is None, |
| 168 | + ) |
| 169 | + attn_weights = None |
| 170 | + |
| 171 | + if reshard_layer is not None: |
| 172 | + # attn_output shape: [bs, seqlen, num_head/sep, head_dim] |
| 173 | + attn_output = reshard_layer( |
| 174 | + attn_output, |
| 175 | + split_axis=1, |
| 176 | + concat_axis=2, |
| 177 | + ) |
| 178 | + # attn_output shape: [bs, seqlen/sep, num_head, head_dim] |
| 179 | + assert ( |
| 180 | + config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0 |
| 181 | + ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}" |
| 182 | + q_len = q_len // config.sep_parallel_degree |
| 183 | + num_heads = num_heads * config.sep_parallel_degree |
| 184 | + |
| 185 | + if sequence_parallel: |
| 186 | + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) |
| 187 | + else: |
| 188 | + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) |
| 189 | + return (attn_output, attn_weights) if output_attentions else attn_output |
0 commit comments