Skip to content

support GQA #7906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 23, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,20 +923,44 @@ def forward(
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
if self.use_fused_rope:
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
else:
query_states, key_states, _ = fused_rotary_position_embedding(
query_states,
key_states,
v=None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
# paddle version > 2.6 or develop support q and k/v with different num_heads
paddle_version = float(paddle.__version__[:3])
if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
query_states, _, _ = fused_rotary_position_embedding(
query_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
key_states, _, _ = fused_rotary_position_embedding(
key_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
Comment on lines +936 to +953
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GQA的时候,原来的代码 用 fused_rotary_position_embedding 是有问题的吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,Paddle旧版本的fused_rotary_position_embedding不支持传入的q和k/v 有不同的heads,所以等效的方式是单独处理q,k,需要分别调用2次接口。

我们在dev已经做了支持,所以可以直接调用1次接口。

else:
query_states, key_states, _ = fused_rotary_position_embedding(
query_states,
key_states,
v=None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand All @@ -955,8 +979,11 @@ def forward(

# TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# paddle version > 2.6 or develop support flash-attn with gqa/mqa
paddle_version = float(paddle.__version__[:3])
if (paddle_version != 0.0) and (paddle_version <= 2.6):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
if (
Expand Down