
Commit 4b86c51

Optimization of the SD3 transformer (#713)
Co-authored-by: changwenbin <changwenbin@baidu.com>
1 parent 2962cab commit 4b86c51

File tree

8 files changed: +765 −59 lines changed

paddlemix/triton_ops/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -21,6 +21,8 @@
         fused_rotary_emb,
         paddle_use_triton,
         rms_norm,
+        split_concat,
+        triton_split,
         weight_only_int8,
     )
     from .triton_utils import (
@@ -39,6 +41,8 @@
         "rms_norm",
         "get_dtype_str",
         "fused_rotary_emb",
+        "split_concat",
+        "triton_split",
     ]
 except:
     pass
```
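
A minimal usage sketch of the new exports (assuming a working Triton installation, since the surrounding `try/except` otherwise leaves these names unexported):

```python
# Illustrative only; these names are available from the package namespace
# only when Triton is installed (see the try/except in __init__.py).
from paddlemix.triton_ops import split_concat, triton_split
```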

paddlemix/triton_ops/triton_ops.py

Lines changed: 295 additions & 0 deletions
@@ -1580,3 +1580,298 @@ def fused_rotary_emb(

```python
        outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
    )
    return q_out, k_out, v_out


########################### split concat ###############################
split_concat_template = (
    """
std::vector<paddle::Tensor> ${op_name}_func(
    const paddle::Tensor &x,
    const paddle::Tensor &y) {

  int batch = x.dims()[0];

  int seq_qkv = x.dims()[1];
  int seq_eqkv = y.dims()[1];
  int output_hidden = x.dims()[2] / 3;

  auto qkv = get_tensor_ptr(x);
  auto eqkv = get_tensor_ptr(y);

  auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
  auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
  auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());

  auto out0 = get_tensor_ptr(out0_tensor);
  auto out1 = get_tensor_ptr(out1_tensor);
  auto out2 = get_tensor_ptr(out2_tensor);

  auto run_stream = out0_tensor.stream();
"""
    + tune_and_invoke_part
    + """
    return {out0_tensor, out1_tensor, out2_tensor};
}

std::vector<std::vector<int64_t>> ${op_name}_InferShape(
        const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {

  int64_t seq1 = A_shape[1];
  int64_t seq2 = B_shape[1];
  int64_t seq = -1;
  if (seq1 > 0 && seq2 > 0){
    seq = seq1 + seq2;
  }
  std::vector<int64_t> out_shape = {A_shape[0], seq, A_shape[2]/3};

  return {out_shape, out_shape, out_shape};
}

std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
    return {A_dtype, A_dtype, A_dtype};
}

PD_BUILD_OP(${op_name})
    .Inputs({"x", "y"})
    .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


@paddle_use_triton(
    custom_op_template=split_concat_template,
    key=["1"],
)
def split_concat_kernel(
    out0,
    out1,
    out2,
    qkv,
    eqkv,
    batch,
    seq_qkv,
    seq_eqkv,
    output_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    out_id = tl.program_id(axis=0)
    batch = tl.program_id(axis=1)
    out_row = tl.program_id(axis=2)
    if out_row < seq_qkv:
        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
    else:
        read_ptr = (
            out_id * output_hidden
            + (out_row - seq_qkv) * 3 * output_hidden
            + batch * seq_eqkv * output_hidden * 3
            + eqkv
        )

    read_offsets = tl.arange(0, BLOCK_SIZE)
    mask = read_offsets < output_hidden
    read_data = tl.load(read_ptr + read_offsets, mask=mask)

    real_output = out0
    if out_id == 1:
        real_output = out1
    elif out_id == 2:
        real_output = out2

    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets

    tl.store(write_ptr, read_data, mask=mask)


def split_concat(x, y):
    assert len(x.shape) == 3
    assert len(y.shape) == 3

    assert x.shape[0] == y.shape[0]
    assert x.shape[2] == y.shape[2]

    batch = x.shape[0]
    seq_qkv = x.shape[1]
    hidd_x = x.shape[2]
    seq_eqkv = y.shape[1]
    ouput_hidden = hidd_x // 3
    BLOCK_SIZE = triton.next_power_of_2(ouput_hidden)
    op_name = "split_concat"
    op_name += get_dtype_str(x.dtype)
    op_name += f"_{BLOCK_SIZE}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
        grid = ("3", "batch", "seq_qkv + seq_eqkv")

        split_concat_kernel[(op_name, grid)](
            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=BLOCK_SIZE
        )

    if in_dynamic_or_pir_mode():
        print(f"== we are in dynamic mode, op_name: {op_name}")
        outs = _C_ops._run_custom_op(
            op_name,
            x,
            y,
        )
        return outs[0], outs[1], outs[2]
    else:
        print(f"== we are in dynamic to static mode, op_name: {op_name}")
        helper = LayerHelper(op_name, **locals())
        inputs = {
            "x": x,
            "y": y,
        }
        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(
            type=op_name,
            inputs=inputs,
            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
        )
        return out0, out1, out2
```
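
For reference, a plain-Paddle sketch of the semantics this fused op appears to implement, as read from the kernel: both inputs pack Q/K/V along the last axis, and each of the three outputs concatenates the corresponding chunk of `x` and `y` along the sequence axis (names below are illustrative, not from the commit).

```python
import paddle

def split_concat_reference(x, y):
    # x: [batch, seq_qkv, 3 * hidden]   (e.g. the image-stream packed QKV)
    # y: [batch, seq_eqkv, 3 * hidden]  (e.g. the encoder/text-stream packed QKV)
    q_x, k_x, v_x = paddle.split(x, num_or_sections=3, axis=-1)
    q_y, k_y, v_y = paddle.split(y, num_or_sections=3, axis=-1)
    # Each output is [batch, seq_qkv + seq_eqkv, hidden], matching out0/out1/out2 above.
    q = paddle.concat([q_x, q_y], axis=1)
    k = paddle.concat([k_x, k_y], axis=1)
    v = paddle.concat([v_x, v_y], axis=1)
    return q, k, v
```

The gain over this reference comes from doing the gather in a single kernel launch instead of separate split and concat ops with intermediate tensors.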
```python
########################### triton split ###############################
triton_split_template = (
    """
std::vector<paddle::Tensor> ${op_name}_func(
    const paddle::Tensor &x,
    const std::vector<int64_t> num_or_sections,
    const int64_t axis) {

  int output_batch = x.dims()[0];
  int output_seq0 = num_or_sections[0];
  int output_seq1 = num_or_sections[1];
  int output_hidden = x.dims()[2];

  auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
  auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());

  auto out0 = get_tensor_ptr(out0_tensor);
  auto out1 = get_tensor_ptr(out1_tensor);

  auto input = get_tensor_ptr(x);

  auto run_stream = out0_tensor.stream();
"""
    + tune_and_invoke_part
    + """
    return {out0_tensor, out1_tensor};
}

std::vector<std::vector<int64_t>> ${op_name}_InferShape(
        const std::vector<int64_t>& A_shape) {

  std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
  std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};

  return {out_shape0, out_shape1};
}

std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
    return {A_dtype, A_dtype};
}

PD_BUILD_OP(${op_name})
    .Inputs({"x"})
    .Outputs({"out0_tensor", "out1_tensor"})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


@paddle_use_triton(
    custom_op_template=triton_split_template,
    key=["1"],
)
def triton_split_kernel(
    out0,
    out1,
    input,
    output_seq0,
    output_seq1,
    output_batch,
    output_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    batch = tl.program_id(axis=0)
    out_row = tl.program_id(axis=1)
    read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input

    read_offsets = tl.arange(0, BLOCK_SIZE)
    mask = read_offsets < output_hidden
    read_data = tl.load(read_ptr + read_offsets, mask=mask)

    if out_row < output_seq0:
        write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
    else:
        write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets

    tl.store(write_ptr, read_data, mask=mask)


def triton_split(x, num_or_sections=[-1, -1], axis=1):
    assert len(x.shape) == 3
    output_batch = x.shape[0]
    output_seq0 = num_or_sections[0]
    output_seq1 = num_or_sections[1]
    output_hidden = x.shape[2]

    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
    op_name = "triton_split"
    op_name += get_dtype_str(x.dtype)
    op_name += f"_{BLOCK_SIZE}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype)
        out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype)
        grid = ("output_batch", "output_seq0+output_seq1")

        triton_split_kernel[(op_name, grid)](
            out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=2048
        )

    if in_dynamic_or_pir_mode():
        print(f"== we are in dynamic mode, op_name: {op_name}")
        outs = _C_ops._run_custom_op(
            op_name,
            x,
            num_or_sections,
            axis,
        )
        return outs[0], outs[1]
    else:
        print(f"== we are in dynamic to static mode, op_name: {op_name}")
        helper = LayerHelper(op_name, **locals())
        inputs = {
            "x": x,
        }
        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(
            type=op_name,
            inputs=inputs,
            attrs={
                "num_or_sections": num_or_sections,
                "axis": axis,
            },
            outputs={"out0_tensor": out0, "out1_tensor": out1},
        )
        return out0, out1
```
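
For orientation, `triton_split` behaves like a two-section `paddle.split` along the sequence axis; note that the C++ `InferShape` above hard-codes sequence lengths of 1024 and 154, which appear to be the fixed shapes of this SD3 deployment. A hedged reference sketch with illustrative names:

```python
import paddle

def triton_split_reference(x, num_or_sections=[1024, 154], axis=1):
    # x: [batch, seq0 + seq1, hidden] -> ([batch, seq0, hidden], [batch, seq1, hidden])
    out0, out1 = paddle.split(x, num_or_sections=num_or_sections, axis=axis)
    return out0, out1
```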

ppdiffusers/deploy/sd3/README.md

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@

# Stable Diffusion 3 high-performance inference

- Paddle Inference provides a high-performance inference implementation for the Stable Diffusion 3 model, improving inference performance by 70%+.

Environment setup:
```shell
# Install triton and make it compatible with paddle
python -m pip install triton
python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"

# Install the develop (nightly) build of paddle; pick the package matching your CUDA version (CUDA 12.3 is used here)
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/

# Point LD_LIBRARY_PATH at libCutlassGemmEpilogue.so
# For details, see https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/README.md
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
```

High-performance inference command:
```shell
# Run FP16 inference
python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \
    --num-inference-steps 50 --inference_optimize 1 \
    --benchmark 1
```

- Latency measured on an NVIDIA A100-SXM4-40GB:

| Paddle Inference | PyTorch | Paddle dynamic graph |
| ---------------- | ------- | -------------------- |
| 1.2 s            | 1.78 s  | 4.202 s              |
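
For context, the benchmark script drives a ppdiffusers text-to-image pipeline; a hypothetical minimal sketch of such a call is shown below. The `StableDiffusion3Pipeline` name, `paddle_dtype` argument, and checkpoint id are assumptions mirroring the diffusers API, not taken from this commit.

```python
# Hypothetical sketch; the numbers above come from text_to_image_generation-stable_diffusion_3.py,
# not from this snippet.
import paddle
from ppdiffusers import StableDiffusion3Pipeline  # assumed API name

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint id
    paddle_dtype=paddle.float16,
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    height=512,
    width=512,
    num_inference_steps=50,
).images[0]
image.save("sd3_fp16_512.png")
```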
