@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
419
419
return (
420
420
cos .cast (x .dtype ) if cos .dtype != x .dtype else cos ,
421
421
sin .cast (x .dtype ) if sin .dtype != x .dtype else sin ,
422
- self .cos_sin_table .cast (x .dtype )
423
- if self .cos_sin_table is not None and self .cos_sin_table .dtype != x .dtype
424
- else self .cos_sin_table ,
425
422
)
426
423
424
+ def get_fused_cos_sin (self , x , seq_len = None ):
425
+ if self .cos_sin_table is not None and self .cos_sin_table .dtype != x .dtype :
426
+ return self .cos_sin_table .cast (x .dtype )
427
+ else :
428
+ return self .cos_sin_table
429
+
427
430
428
431
class LlamaLinearScalingRotaryEmbedding (LlamaRotaryEmbedding ):
429
432
def __init__ (self , dim , max_position_embeddings = 2048 , base = 10000 , scaling_factor = 1.0 ):
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
482
485
def forward (self , x , seq_len = None ):
483
486
# x: [bs, num_attention_heads, seq_len, head_size]
484
487
if seq_len > self .max_position_embeddings :
485
- scale_cos , scale_sin , scale_cos_sin = self ._scale_cos_sin (seq_len = seq_len )
488
+ scale_cos , scale_sin , _ = self ._scale_cos_sin (seq_len = seq_len )
486
489
else :
487
- scale_cos , scale_sin , scale_cos_sin = self .cos_cached , self .sin_cached , self . cos_sin_table
490
+ scale_cos , scale_sin = self .cos_cached , self .sin_cached
488
491
cos = scale_cos [:, :seq_len , :, ...]
489
492
sin = scale_sin [:, :seq_len , :, ...]
490
493
return (
491
494
cos .cast (x .dtype ) if cos .dtype != x .dtype else cos ,
492
495
sin .cast (x .dtype ) if sin .dtype != x .dtype else sin ,
493
- scale_cos_sin .cast (x .dtype )
494
- if scale_cos_sin is not None and scale_cos_sin .dtype != x .dtype
495
- else scale_cos_sin ,
496
496
)
497
497
498
+ def get_fused_cos_sin (self , x , seq_len = None ):
499
+ if seq_len > self .max_position_embeddings :
500
+ _ , _ , scale_cos_sin = self ._scale_cos_sin (seq_len = seq_len )
501
+ else :
502
+ scale_cos_sin = self .cos_sin_table
503
+ if scale_cos_sin is not None and scale_cos_sin .dtype != x .dtype :
504
+ return scale_cos_sin .cast (x .dtype )
505
+ else :
506
+ return scale_cos_sin
507
+
498
508
499
509
def rotate_half (x ):
500
510
"""Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
943
953
sin .cast (value_states .dtype ) if sin .dtype != value_states .dtype else sin ,
944
954
)
945
955
else :
946
- cos , sin , _ = self .rotary_emb (value_states , seq_len = kv_seq_len )
956
+ cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
947
957
948
958
query_states , key_states = apply_rotary_pos_emb (query_states , key_states , cos , sin , position_ids )
949
959
0 commit comments