@@ -79,7 +79,7 @@ def swiglu(x, y=None):
 )
 
 try:
-    if get_env_device() == "npu":
+    if get_env_device() in ["npu", "gcu"]:
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -410,6 +410,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -418,6 +419,9 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            self.cos_sin_table.cast(x.dtype)
+            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
+            else self.cos_sin_table,
         )
 
 
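For context, here is a minimal standalone sketch (not part of the diff) of the tensors cached above. The toy sizes `head_dim`, `seq_len`, and `base`, and the `inv_freq`/`freqs` construction, are assumptions mirroring the standard rotary-embedding setup that sits outside this hunk; only the last assignment corresponds to the `gcu`-only table the diff introduces, which packs the half-dimension cos and sin values side by side instead of the broadcast-ready `[1, seq_len, 1, dim]` caches:

import paddle

head_dim, seq_len, base = 4, 3, 10000.0  # toy sizes for illustration (assumption)
inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
t = paddle.arange(seq_len, dtype="float32")
freqs = paddle.einsum("i,j->ij", t, inv_freq)      # [seq_len, head_dim // 2]
emb = paddle.concat([freqs, freqs], axis=-1)       # [seq_len, head_dim]

cos_cached = emb.cos()[None, :, None, :]           # [1, seq_len, 1, head_dim]
sin_cached = emb.sin()[None, :, None, :]           # [1, seq_len, 1, head_dim]
cos_sin_table = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)  # [seq_len, head_dim], gcu only

print(cos_cached.shape, sin_cached.shape, cos_sin_table.shape)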
@@ -439,6 +443,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
 
 class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -471,19 +476,23 @@ def _scale_cos_sin(self, seq_len):
         # [1, seqlen, 1, dim]
         scale_cos = emb.cos()[None, :, None, :]
         scale_sin = emb.sin()[None, :, None, :]
-        return scale_cos, scale_sin
+        scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
+        return scale_cos, scale_sin, scale_cos_sin
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin = self.cos_cached, self.sin_cached
+            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            scale_cos_sin.cast(x.dtype)
+            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
+            else scale_cos_sin,
         )
 
 
@@ -638,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -934,7 +943,7 @@ def forward(
                         sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
                     )
                 else:
-                    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                    cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -1398,7 +1407,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
-        elif get_env_device() == "xpu":
+        elif get_env_device() in ["xpu", "gcu"]:
             x = paddle.to_tensor(0.0, dtype=dtype)
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
             expanded_attn_mask = expanded_attn_mask.astype(dtype)
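The `xpu`/`gcu` branch above rewrites the boolean attention mask as an additive float mask: 0 where attention is allowed, the dtype's minimum value elsewhere. A hedged, self-contained sketch of that idea follows; the mask values and the names `keep`, `zero`, and `neg_min` are illustrative only and not taken from the model code:

import paddle

dtype = paddle.float16
keep = paddle.to_tensor([[True, False], [True, True]])    # True = position may be attended
zero = paddle.to_tensor(0.0, dtype=dtype)
neg_min = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
additive_mask = paddle.where(keep, zero, neg_min)          # 0 where allowed, ~-65504 elsewhere
print(additive_mask)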
@@ -1528,7 +1537,7 @@ def forward(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
         is_casual = False
-        if self.config.use_flash_attention:
+        if self.config.use_flash_attention and get_env_device() != "gcu":
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None: