from paddlenlp.transformers.linear_utils import Linear
from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
from paddlenlp.transformers.model_utils import PretrainedModel
+from paddlenlp.utils.tools import get_env_device

from paddlemix.models.flash_attn_utils import (
    create_attention_module,
    ...
)
from .bert_padding import index_first_axis, pad_input, unpad_input
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig

+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
logger = logging.get_logger(__name__)

flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
@@ -407,7 +413,12 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) ->
    sin = freqs.sin()
    cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
    sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
-    output = tensor * cos + rotate_half(tensor) * sin
+    if get_env_device() == "xpu" and fused_rotary_position_embedding is not None:
+        output, _, _ = fused_rotary_position_embedding(
+            tensor, sin=sin, cos=cos, use_neox_rotary_style=False
+        )
+    else:
+        output = tensor * cos + rotate_half(tensor) * sin
    output = paddle.cast(output, orig_dtype)
    return output

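For reference, a minimal standalone sketch of the dispatch introduced above: prefer the fused XPU rotary kernel when it can be imported, otherwise fall back to the eager rotate_half formulation. The apply_rotary helper and the on_xpu flag below are illustrative, not part of the PR; the fused call mirrors the usage in the hunk.

# Sketch only. Assumes PaddlePaddle; builds without the incubate kernel take the eager path.
import paddle

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None  # older Paddle builds do not ship the fused kernel


def rotate_half(x):
    # Standard RoPE helper: split the last dimension in half and rotate, (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)


def apply_rotary(tensor, cos, sin, on_xpu=False):
    # Use the fused kernel only when running on XPU and the import succeeded.
    if on_xpu and fused_rotary_position_embedding is not None:
        output, _, _ = fused_rotary_position_embedding(tensor, sin=sin, cos=cos, use_neox_rotary_style=False)
        return output
    # Eager fallback, identical to the line replaced in the hunk above.
    return tensor * cos + rotate_half(tensor) * sin
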
@@ -463,6 +474,12 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )
+        if get_env_device() == "xpu":
+            self.mlp = nn.Sequential(
+                Linear(self.hidden_size, self.hidden_size),
+                nn.GELU(),
+                Linear(self.hidden_size, dim),
+            )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
@@ -475,6 +492,9 @@ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)
+        if get_env_device() == "xpu":
+            self.fc1 = Linear(dim, hidden_dim)
+            self.fc2 = Linear(hidden_dim, dim)

    def forward(self, x) -> paddle.Tensor:
        return self.fc2(self.act(self.fc1(x)))
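
The same substitution recurs in the patch merger above and in both vision attention blocks below: on XPU, the nn.Linear layers are rebuilt with paddlenlp's linear_utils.Linear. A condensed sketch of that selection pattern follows; the make_linear helper is illustrative and not part of the PR.

# Sketch only. Assumes paddlenlp is installed; on non-XPU devices plain nn.Linear is kept.
import paddle.nn as nn
from paddlenlp.transformers.linear_utils import Linear
from paddlenlp.utils.tools import get_env_device


def make_linear(in_features, out_features, bias_attr=None):
    # Pick the linear implementation once, based on the detected device.
    linear_cls = Linear if get_env_device() == "xpu" else nn.Linear
    return linear_cls(in_features, out_features, bias_attr=bias_attr)


# Usage mirroring the attention hunks: qkv = make_linear(dim, dim * 3, bias_attr=True)
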
@@ -486,6 +506,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
        self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
        self.head_dim = dim // num_heads  # must added

    def forward(
@@ -525,6 +548,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
        self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
        self.head_dim = dim // num_heads  # must added

    def forward(
@@ -657,6 +683,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
        self.variance_epsilon = eps

    def forward(self, hidden_states):
+        if get_env_device() == "xpu":
+            try:
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+            except ImportError:
+                raise NotImplementedError(
+                    f"Implementation of fused_rms_norm is not available on xpu. Please install paddle_xpu to use this feature"
+                )
        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
                variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
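
For context, the eager branch that the new XPU path short-circuits is a standard RMSNorm computed in float32. A minimal standalone sketch of that computation (illustrative only; eps corresponds to variance_epsilon above):

# Sketch only. Assumes PaddlePaddle; mirrors the float32 variance computation shown above.
import paddle


def rms_norm(hidden_states, weight, eps=1e-6):
    input_dtype = hidden_states.dtype
    # Variance of the activations along the last axis, accumulated in float32 for stability.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    hidden_states = paddle.rsqrt(variance + eps) * hidden_states.astype("float32")
    # Cast back to the input dtype and scale by the learned weight.
    return weight * hidden_states.astype(input_dtype)
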
@@ -1193,7 +1228,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel):

    def _init_weights(self, layer):
        std = 0.2
-        if isinstance(layer, (nn.Linear, nn.Conv3D)):
+        if isinstance(layer, (nn.Linear, nn.Conv3D, Linear)):
            nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
            if layer.bias is not None:
                nn.initializer.Constant(0.0)(layer.bias)
@@ -1558,6 +1593,9 @@ def __init__(self, config, embedding_weights=None, transpose_y=False):
                shape=[config.hidden_size, vocab_size],
                dtype=paddle.get_default_dtype(),
            )
+        if get_env_device() == "xpu":
+            import paddle_xpu.layers.nn.linear as xpu_linear
+            self.xpu_parallel_matmul = xpu_linear.parallel_matmul()

        # Must set distributed attr for Tensor Parallel !
        self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
@@ -1573,9 +1611,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
        if self.weight.dtype != hidden_states.dtype:
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)

-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
-        )
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul.forward(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
+        else:
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
        return logits
