Commit 32d66ef

[GCU] Support llama for GCU
1 parent 5170664 commit 32d66ef

File tree: 5 files changed (+51, -14 lines)

examples/benchmark/wiki_lambada/eval.py

Lines changed: 9 additions & 3 deletions

@@ -57,8 +57,8 @@ def get_parser():
         "--device",
         type=str,
         default="gpu",
-        choices=["cpu", "gpu", "xpu", "npu"],
-        help="select cpu, gpu, xpu devices.",
+        choices=["cpu", "gpu", "xpu", "npu", "gcu"],
+        help="select cpu, gpu, xpu, gcu devices.",
     )
     parser.add_argument(
         "--dtype",
@@ -67,6 +67,12 @@ def get_parser():
         choices=["bfloat16", "float16", "float32"],
         help="set the dtype of model",
     )
+    parser.add_argument(
+        "--use_flash_attention",
+        type=bool,
+        default=False,
+        help="Whether to use flash attention",
+    )
 
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
@@ -316,7 +322,7 @@ def do_generation():
         tensor_parallel_output=False,
         tensor_parallel_degree=args.tensor_parallel_degree,
         tensor_parallel_rank=paddle.distributed.get_rank(),
-        use_flash_attention=False,
+        use_flash_attention=args.use_flash_attention,
         dtype=args.dtype,  # todo enable set dtype to avoid additional mem usage
     )
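A note on the new --use_flash_attention option: argparse's type=bool applies bool() to the raw command-line string, so any non-empty value (including "False") parses as True, and only omitting the flag keeps the default. A small standalone sketch, not part of the commit, showing that behavior:

import argparse

# Same flag declaration as the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--use_flash_attention", type=bool, default=False, help="Whether to use flash attention")

print(parser.parse_args([]).use_flash_attention)                                  # False (flag omitted, default used)
print(parser.parse_args(["--use_flash_attention", "False"]).use_flash_attention)  # True, because bool("False") is True
print(parser.parse_args(["--use_flash_attention", ""]).use_flash_attention)       # False, only the empty string is falsy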

paddlenlp/generation/utils.py

Lines changed: 2 additions & 0 deletions

@@ -1208,6 +1208,8 @@ def sample(
             probs = TopKProcess(probs, top_k, min_tokens_to_keep)
         if top_p is not None and top_p < 1.0:
             probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+        if paddle.device.is_compiled_with_custom_device("gcu"):
+            probs = paddle.cast(probs, "float32")
 
         # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
         next_tokens = paddle.multinomial(probs)
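For context, a minimal standalone sketch of the sampling step touched here. The guard mirrors the added lines; the underlying assumption (not stated in the commit) is that the GCU path wants float32 probabilities for multinomial sampling:

import paddle

# Toy post-top-k/top-p probabilities; in the real path these may be fp16/bf16.
probs = paddle.to_tensor([[0.1, 0.2, 0.7]])

# Same guard as the added lines: cast only on a GCU build.
if paddle.device.is_compiled_with_custom_device("gcu"):
    probs = paddle.cast(probs, "float32")

next_tokens = paddle.multinomial(probs)  # one sampled token id per row
print(next_tokens)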

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 21 additions & 3 deletions

@@ -41,7 +41,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() == "npu":
+    if get_env_device() in ["npu", "gcu"]:
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -53,13 +53,18 @@ def swiglu(x, y=None):
 
 
 def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
-    assert past_key_value is None, "fuse rotary not support cache kv for now"
+    if get_env_device() != "gcu":
+        assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
+    elif get_env_device() == "gcu":
+        query_states, key_states = core.eager._run_custom_op(
+            "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
+        )
     else:
         # paddle version > 2.6 or develop support q and k/v with different num_heads
         paddle_version = float(paddle.__version__[:3])
@@ -103,6 +108,8 @@ def rms_norm_fused(x_in, w, eps):
 def fusion_rms_norm(hidden_states, weight, variance_epsilon):
     if get_env_device() == "npu":
         return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
+    elif get_env_device() == "gcu":
+        return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "xpu":
         try:
             import paddle_xpu_nn  # noqa: F821
@@ -158,6 +165,17 @@ def fusion_flash_attention(
             False,
             npu_is_casual,
         )[0]
+    elif get_env_device() == "gcu":
+        attn_output = core.eager._run_custom_op(
+            "fused_sdp_flash_attention_gcu",
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            0.0,
+            attention_mask is None,
+            True,
+        )[0]
     else:
         attn_output = F.scaled_dot_product_attention(
             query_states,
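As an orienting reference (not part of the commit), the fused rms_norm_npu / rms_norm_gcu custom ops above are presumably drop-in replacements for the standard LLaMA RMSNorm; a sketch of the unfused computation, under that assumption:

import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # x / sqrt(mean(x^2) + eps) * weight, with the reduction in float32 for stability.
    variance = hidden_states.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = hidden_states.astype("float32") * paddle.rsqrt(variance + variance_epsilon)
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)

x = paddle.randn([2, 8, 64])
w = paddle.ones([64])
print(rms_norm_reference(x, w, 1e-6).shape)  # [2, 8, 64]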

paddlenlp/transformers/llama/modeling.py

Lines changed: 17 additions & 8 deletions

@@ -79,7 +79,7 @@ def swiglu(x, y=None):
     )
 
 try:
-    if get_env_device() == "npu":
+    if get_env_device() in ["npu", "gcu"]:
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -410,6 +410,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -418,6 +419,9 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            self.cos_sin_table.cast(x.dtype)
+            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
+            else self.cos_sin_table,
         )
 
 
@@ -439,6 +443,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
 
 class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -471,19 +476,23 @@ def _scale_cos_sin(self, seq_len):
         # [1, seqlen, 1, dim]
         scale_cos = emb.cos()[None, :, None, :]
         scale_sin = emb.sin()[None, :, None, :]
-        return scale_cos, scale_sin
+        scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
+        return scale_cos, scale_sin, scale_cos_sin
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin = self.cos_cached, self.sin_cached
+            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            scale_cos_sin.cast(x.dtype)
+            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
+            else scale_cos_sin,
         )
 
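For readers of the rotary-embedding changes above: on GCU the embeddings return a third value, a packed table with the cos half and the sin half of the per-position frequencies concatenated along the last axis, which fusion_rope then hands to fused_rotary_embedding_gcu. A small shape sketch; the freqs construction below follows the usual LLaMA rotary setup and is an assumption, since that code sits outside the hunks shown. The remaining modeling.py hunks follow the sketch.

import paddle

seq_len, head_dim, base = 16, 64, 10000.0

# Usual rotary setup (assumed): per-position angles for each frequency pair.
inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
t = paddle.arange(seq_len, dtype="float32")
freqs = paddle.outer(t, inv_freq)                     # [seq_len, head_dim // 2]

emb = paddle.concat([freqs, freqs], axis=-1)          # [seq_len, head_dim]
cos_cached = emb.cos()[None, :, None, :]              # [1, seq_len, 1, head_dim]
sin_cached = emb.sin()[None, :, None, :]

# GCU-only packed table from the diff: cos half followed by sin half.
cos_sin_table = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)  # [seq_len, head_dim]
print(cos_cached.shape, sin_cached.shape, cos_sin_table.shape)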

@@ -638,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -934,7 +943,7 @@ def forward(
                 sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
             )
         else:
-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -1398,7 +1407,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
-        elif get_env_device() == "xpu":
+        elif get_env_device() in ["xpu", "gcu"]:
             x = paddle.to_tensor(0.0, dtype=dtype)
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
             expanded_attn_mask = expanded_attn_mask.astype(dtype)
@@ -1528,7 +1537,7 @@ def forward(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]
             is_casual = False
-            if self.config.use_flash_attention:
+            if self.config.use_flash_attention and get_env_device() != "gcu":
                 is_casual = is_casual_mask(attention_mask)
                 if get_env_device() != "npu":
                     if is_casual and alibi is None:
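The xpu/gcu branch of _prepare_decoder_attention_mask converts the boolean keep-mask into an additive bias in the computation dtype: 0 where attention is allowed, the dtype's minimum where it is masked. A standalone sketch of that conversion (the final paddle.where step sits just past the hunk's context and is assumed to match the neighboring branches):

import paddle

dtype = paddle.float32  # the model's compute dtype (e.g. float16/bfloat16) in the real path
# Boolean keep-mask, shape [bs, 1, q_len, kv_len]: True = attend, False = masked.
keep = paddle.to_tensor([[[[True, False],
                           [True, True]]]])

zero = paddle.full(keep.shape, 0.0, dtype=dtype)
neg = paddle.full(keep.shape, paddle.finfo(dtype).min, dtype=dtype)
additive_mask = paddle.where(keep, zero, neg)  # added to attention scores before softmax
print(additive_mask)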

paddlenlp/utils/tools.py

Lines changed: 2 additions & 0 deletions

@@ -124,6 +124,8 @@ def get_env_device():
         return "gpu"
     elif "npu" in paddle.device.get_all_custom_device_type():
         return "npu"
+    elif "gcu" in paddle.device.get_all_custom_device_type():
+        return "gcu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():
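The new branch relies on the GCU plugin registering its custom device type with Paddle; a quick check of that assumption from a Python shell:

import paddle

# On a build with the GCU plugin installed, "gcu" is expected to appear here
# (assumption: the plugin registers its custom device under that name).
custom_types = paddle.device.get_all_custom_device_type() or []
print(custom_types)

if "gcu" in custom_types:
    print("get_env_device() will now return 'gcu'")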
