@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                import paddle_xpu_nn
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
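Both the device-specific ops and rms_norm_fused above are expected to be numerically equivalent to a plain-Paddle RMSNorm. The following is a minimal reference sketch, not part of the diff; the xpu_rms_norm name and its tuple return value are inferred from the call site above rather than from paddle_xpu_nn documentation.

import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # Plain-Paddle RMSNorm: scale by the reciprocal root-mean-square of the
    # activations (computed in float32 for stability), then apply the weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.astype("float32")
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * paddle.rsqrt(variance + variance_epsilon)
    return (hidden_states * weight.astype("float32")).astype(input_dtype)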
@@ -582,12 +586,33 @@ def __init__(self, config):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
@@ -619,15 +644,24 @@ def __init__(self, config):
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
 
     def forward(self, x):
         if self.fuse_attention_ffn:
+            # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                import paddle_xpu_nn  # noqa: F821
+
+                out = self.gate_up_fused_proj(x)
+                out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                out = self.down_proj(out)
+                return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
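As a point of reference for the xpu_swiglu call above, a plain-Paddle fused SwiGLU over the concatenated [gate, up] projection is sketched below. The half-split order, and what the turn=True argument selects inside paddle_xpu_nn, are assumptions rather than facts established by this diff.

import paddle
import paddle.nn.functional as F

def fused_swiglu_reference(gate_up, axis=-1):
    # gate_up has shape [..., 2 * intermediate_size]: split it into the gate
    # and up halves, then return silu(gate) * up.
    gate, up = paddle.chunk(gate_up, chunks=2, axis=axis)
    return F.silu(gate) * up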
@@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() != "npu":
-            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
+            if (
+                "gpu" not in paddle.device.get_device()
+                and "xpu" not in paddle.device.get_device()
+                or fused_rotary_position_embedding is None
+            ):
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -741,36 +800,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                         gather_output=False,
                     )
                 else:
-                    self.k_proj = nn.Linear(
+                    self.k_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
-                    self.v_proj = nn.Linear(
+                    self.v_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
 
         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
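All three branches above build the same additive mask: allowed positions become 0.0 and masked positions become the most negative representable value for the dtype, so the mask can be added directly to the attention scores before softmax. A minimal sketch of that conversion, with illustrative names and shapes:

import paddle

def to_additive_mask(bool_mask, dtype=paddle.float32):
    # bool_mask is True where attention is allowed and False where it is masked.
    # finfo(dtype).min drives masked logits toward zero probability after softmax.
    zero = paddle.to_tensor(0.0, dtype=dtype)
    neg_inf = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    return paddle.where(bool_mask, zero, neg_inf)

additive = to_additive_mask(paddle.to_tensor([[True, True, False]]))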
@@ -1708,6 +1772,10 @@ def __init__(self, config: LlamaConfig):
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if get_env_device() == "xpu":
+            import paddle_xpu
+
+            self.xpu_parallel_matmul = paddle_xpu.layers.nn.parallel_matmul()
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1789,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
 
-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul(
+                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+            )
+        else:
+            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
 
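For context, parallel_matmul (and presumably paddle_xpu's replacement) projects the final hidden states onto the output-embedding weight, handling the tensor-parallel vocab shard according to tensor_parallel_output. Its single-card core reduces to a plain matmul, sketched below; the extra training flag passed to the XPU version is specific to paddle_xpu and not modeled here.

import paddle

def lm_head_matmul_single_card(hidden_states, weight):
    # Single-card equivalent of the LM head projection:
    # [batch, seq, hidden] @ [hidden, vocab] -> [batch, seq, vocab] logits.
    # Under tensor parallelism the real helper also gathers (or keeps sharded)
    # the vocab dimension, depending on tensor_parallel_output.
    return paddle.matmul(hidden_states, weight)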