Commit df78b71

update
1 parent 9a2f1c5 commit df78b71

File tree: 2 files changed, +190 -1 lines

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    else:
        # paddle version > 2.6 or develop support q and k/v with different num_heads
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
            )
    return rms_norm_fused(hidden_states, weight, variance_epsilon)


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
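
For reference, the swiglu fallback and the fused RMSNorm path in the new file reduce to plain eager paddle ops. The sketch below is not part of the commit; the function names, shapes, and hidden size are illustrative assumptions, and it only shows the math the fused kernels are expected to match.

import paddle
import paddle.nn.functional as F


def swiglu_reference(x):
    # Same math as the swiglu fallback above: split the last axis in half,
    # then gate one half with SiLU of the other.
    a, b = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(a) * b


def rms_norm_reference(hidden_states, weight, eps):
    # Eager-mode RMSNorm: scale by the reciprocal root-mean-square over the last axis.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    normed = paddle.rsqrt(variance + eps) * hidden_states.astype("float32")
    return (normed * weight).astype(hidden_states.dtype)


x = paddle.randn([2, 8, 64])   # [batch, seq_len, 2 * intermediate] (illustrative)
w = paddle.ones([64])          # RMSNorm scale for hidden size 64 (illustrative)
print(swiglu_reference(x).shape)                                    # [2, 8, 32]
print(rms_norm_reference(paddle.randn([2, 8, 64]), w, 1e-6).shape)  # [2, 8, 64]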

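The new file also gates its fused paths on two different version checks: fusion_rope compares float(paddle.__version__[:3]) against 2.6, while fusion_flash_attention compares paddle.version.full_version as a string against "2.5.2", with "0.0.0" denoting a develop build that is treated as newest. A quick check (not part of the commit) of what each expression yields:

import paddle

print(paddle.version.full_version)    # e.g. "2.6.1", or "0.0.0" on a develop build
print(float(paddle.__version__[:3]))  # e.g. 2.6 (0.0 on a develop build)
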
paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def swiglu(x, y=None):
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None
-from .. import fusion_ops
+from . import fusion_ops

rms_norm_fused = fusion_ops.rms_norm_fused

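The one-line fix above changes a parent-package import into a same-package import; relative imports resolve against the importing module's own package. The following self-contained demo uses hypothetical package names (pkg, pkg.sub), not the real PaddleNLP layout, to show how the two spellings pick up different modules:

import sys
import tempfile
from pathlib import Path

# Build a throwaway two-level package: pkg/fusion_ops.py and pkg/sub/fusion_ops.py.
root = Path(tempfile.mkdtemp())
(root / "pkg" / "sub").mkdir(parents=True)
(root / "pkg" / "__init__.py").write_text("")
(root / "pkg" / "fusion_ops.py").write_text("WHERE = 'parent package'\n")
(root / "pkg" / "sub" / "__init__.py").write_text("")
(root / "pkg" / "sub" / "fusion_ops.py").write_text("WHERE = 'current package'\n")
(root / "pkg" / "sub" / "modeling.py").write_text(
    "from . import fusion_ops\n"   # one dot: sibling module in the same package
    "WHERE = fusion_ops.WHERE\n"
)

sys.path.insert(0, str(root))
from pkg.sub import modeling  # noqa: E402

print(modeling.WHERE)  # -> 'current package'; 'from .. import fusion_ops' would yield 'parent package'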