@@ -1580,3 +1580,298 @@ def fused_rotary_emb(
             outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
         )
         return q_out, k_out, v_out
+
+
+########################### split concat ###############################
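+# The C++ template below is the host-side wrapper that the paddle_use_triton
+# machinery appears to render once per (dtype, BLOCK_SIZE) variant: it reads the
+# input shapes, allocates the three output tensors, splices in
+# tune_and_invoke_part to launch the generated Triton kernel, and registers the
+# op together with its InferShape/InferDtype functions via PD_BUILD_OP.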
+split_concat_template = (
+    """
+std::vector<paddle::Tensor> ${op_name}_func(
+    const paddle::Tensor &x,
+    const paddle::Tensor &y) {
+
+  int batch = x.dims()[0];
+
+  int seq_qkv = x.dims()[1];
+  int seq_eqkv = y.dims()[1];
+  int output_hidden = x.dims()[2] / 3;
+
+
+  auto qkv = get_tensor_ptr(x);
+  auto eqkv = get_tensor_ptr(y);
+
+
+  auto out0_tensor = paddle::empty({batch, seq_qkv + seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out1_tensor = paddle::empty({batch, seq_qkv + seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out2_tensor = paddle::empty({batch, seq_qkv + seq_eqkv, output_hidden}, x.dtype(), x.place());
+
+  auto out0 = get_tensor_ptr(out0_tensor);
+  auto out1 = get_tensor_ptr(out1_tensor);
+  auto out2 = get_tensor_ptr(out2_tensor);
+
+
+  auto run_stream = out0_tensor.stream();
+
+"""
+    + tune_and_invoke_part
+    + """
+  return {out0_tensor, out1_tensor, out2_tensor};
+}
+
+std::vector<std::vector<int64_t>> ${op_name}_InferShape(
+    const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
+
+  int64_t seq1 = A_shape[1];
+  int64_t seq2 = B_shape[1];
+  int64_t seq = -1;
+  if (seq1 > 0 && seq2 > 0) {
+    seq = seq1 + seq2;
+  }
+  std::vector<int64_t> out_shape = {A_shape[0], seq, A_shape[2] / 3};
+
+  return {out_shape, out_shape, out_shape};
+}
+
+std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
+  return {A_dtype, A_dtype, A_dtype};
+}
+
+PD_BUILD_OP(${op_name})
+    .Inputs({"x", "y"})
+    .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
+    .SetKernelFn(PD_KERNEL(${op_name}_func))
+    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
+"""
+)
+
+
+@paddle_use_triton(
+    custom_op_template=split_concat_template,
+    key=["1"],
+)
+def split_concat_kernel(
+    out0,
+    out1,
+    out2,
+    qkv,
+    eqkv,
+    batch,
+    seq_qkv,
+    seq_eqkv,
+    output_hidden,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Grid: (3 outputs, batch, seq_qkv + seq_eqkv output rows); each program
+    # copies one row of one of the three outputs for one batch element.
+    out_id = tl.program_id(axis=0)
+    batch = tl.program_id(axis=1)
+    out_row = tl.program_id(axis=2)
+    if out_row < seq_qkv:
+        # The first seq_qkv output rows are read from the packed qkv tensor.
+        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
+    else:
+        # The remaining rows are read from the packed eqkv tensor.
+        read_ptr = (
+            out_id * output_hidden
+            + (out_row - seq_qkv) * 3 * output_hidden
+            + batch * seq_eqkv * output_hidden * 3
+            + eqkv
+        )
+
+    read_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = read_offsets < output_hidden
+    read_data = tl.load(read_ptr + read_offsets, mask=mask)
+
+    real_output = out0
+    if out_id == 1:
+        real_output = out1
+    elif out_id == 2:
+        real_output = out2
+
+    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets
+
+    tl.store(write_ptr, read_data, mask=mask)
+
+
+def split_concat(x, y):
+    assert len(x.shape) == 3
+    assert len(y.shape) == 3
+
+    assert x.shape[0] == y.shape[0]
+    assert x.shape[2] == y.shape[2]
+
+    batch = x.shape[0]
+    seq_qkv = x.shape[1]
+    hidd_x = x.shape[2]
+    seq_eqkv = y.shape[1]
+    output_hidden = hidd_x // 3
+    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
+    op_name = "split_concat"
+    op_name += get_dtype_str(x.dtype)
+    op_name += f"_{BLOCK_SIZE}"
+
+    # Compile and register the custom op the first time this (dtype, BLOCK_SIZE)
+    # variant is requested.
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
+        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
+        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
+        grid = ("3", "batch", "seq_qkv + seq_eqkv")
+
+        split_concat_kernel[(op_name, grid)](
+            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, output_hidden, BLOCK_SIZE=BLOCK_SIZE
+        )
+
+    if in_dynamic_or_pir_mode():
+        print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            x,
+            y,
+        )
+        return outs[0], outs[1], outs[2]
+    else:
+        print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+            "y": y,
+        }
+        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
+        )
+        return out0, out1, out2
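+
+# Illustrative usage sketch: assuming `x` is a packed qkv activation of shape
+# [batch, seq_qkv, 3 * hidden] and `y` a packed encoder qkv activation of shape
+# [batch, seq_eqkv, 3 * hidden],
+#     q, k, v = split_concat(x, y)
+# yields three tensors, each of shape [batch, seq_qkv + seq_eqkv, hidden].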
+
+
+########################### triton split ###############################
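+# Same pattern as split_concat above, but in the other direction: the template
+# below wraps a kernel that splits one [batch, seq0 + seq1, hidden] tensor into
+# two tensors along the sequence axis, taking `num_or_sections` and `axis` as
+# op attributes.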
+triton_split_template = (
+    """
+std::vector<paddle::Tensor> ${op_name}_func(
+    const paddle::Tensor &x,
+    const std::vector<int64_t> num_or_sections,
+    const int64_t axis) {
+
+  int output_batch = x.dims()[0];
+  int output_seq0 = num_or_sections[0];
+  int output_seq1 = num_or_sections[1];
+  int output_hidden = x.dims()[2];
+
+  auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
+  auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());
+
+  auto out0 = get_tensor_ptr(out0_tensor);
+  auto out1 = get_tensor_ptr(out1_tensor);
+
+  auto input = get_tensor_ptr(x);
+
+  auto run_stream = out0_tensor.stream();
+
+"""
+    + tune_and_invoke_part
+    + """
+  return {out0_tensor, out1_tensor};
+}
+
+std::vector<std::vector<int64_t>> ${op_name}_InferShape(
+    const std::vector<int64_t>& A_shape) {
+
+  std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
+  std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};
+
+  return {out_shape0, out_shape1};
+}
+
+std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
+  return {A_dtype, A_dtype};
+}
+
+PD_BUILD_OP(${op_name})
+    .Inputs({"x"})
+    .Outputs({"out0_tensor", "out1_tensor"})
+    .SetKernelFn(PD_KERNEL(${op_name}_func))
+    .Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
+    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
+"""
+)
+
+
+@paddle_use_triton(
+    custom_op_template=triton_split_template,
+    key=["1"],
+)
+def triton_split_kernel(
+    out0,
+    out1,
+    input,
+    output_seq0,
+    output_seq1,
+    output_batch,
+    output_hidden,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Grid: (batch, output_seq0 + output_seq1); each program copies one input
+    # row into either out0 or out1, depending on its position in the sequence.
+    batch = tl.program_id(axis=0)
+    out_row = tl.program_id(axis=1)
+    read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input
+
+    read_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = read_offsets < output_hidden
+    read_data = tl.load(read_ptr + read_offsets, mask=mask)
+
+    if out_row < output_seq0:
+        write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
+    else:
+        write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets
+
+    tl.store(write_ptr, read_data, mask=mask)
+
+
+def triton_split(x, num_or_sections=[-1, -1], axis=1):
+    assert len(x.shape) == 3
+    output_batch = x.shape[0]
+    output_seq0 = num_or_sections[0]
+    output_seq1 = num_or_sections[1]
+    output_hidden = x.shape[2]
+
+    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
+    op_name = "triton_split"
+    op_name += get_dtype_str(x.dtype)
+    op_name += f"_{BLOCK_SIZE}"
+
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype)
+        out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype)
+        grid = ("output_batch", "output_seq0 + output_seq1")
+
+        triton_split_kernel[(op_name, grid)](
+            out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=BLOCK_SIZE
+        )
+
+    if in_dynamic_or_pir_mode():
+        print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            x,
+            num_or_sections,
+            axis,
+        )
+        return outs[0], outs[1]
+    else:
+        print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+        }
+        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            attrs={
+                "num_or_sections": num_or_sections,
+                "axis": axis,
+            },
+            outputs={"out0_tensor": out0, "out1_tensor": out1},
+        )
+        return out0, out1
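+
+# Illustrative usage sketch: assuming `x` has shape [batch, seq0 + seq1, hidden]
+# and is to be split along the sequence axis,
+#     part0, part1 = triton_split(x, num_or_sections=[seq0, seq1], axis=1)
+# returns part0 of shape [batch, seq0, hidden] and part1 of shape
+# [batch, seq1, hidden]. Note that ${op_name}_InferShape above hard-codes the
+# output lengths 1024 and 154 for static-graph shape inference.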