update DIT doc #693

Merged · 13 commits · Aug 30, 2024
23 changes: 18 additions & 5 deletions paddlemix/triton_ops/triton_ops.py
@@ -839,6 +839,13 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
     seq_size = x.shape[1]
     N_npo2 = triton.next_power_of_2(N)

+    # baseline: pure-Paddle reference path, used when INFERENCE_OPTIMIZE_TRITON is unset.
+    if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
+        resi_out_paddle = mha_out * gate_msa.unsqueeze(axis=1) + x
+        norm_hidden_states = paddle.nn.functional.layer_norm(resi_out_paddle, [N], weight, bias, epsilon)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return resi_out_paddle, norm_hidden_states
+
     op_name = "triton_fused_adaLN_scale_residual"
     op_name += get_dtype_str(x.dtype)
     op_name += f"_{N_npo2}_{weight_attr}_{bias_attr}"
@@ -865,9 +872,9 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
         shift_mlp,
         resi_out,
         adaLN_out,
-        M,
+        -1,
         N,
-        seq_size,
+        -1,
         epsilon,
         N_npo2=N_npo2,
         weight_attr=weight_attr,
@@ -1072,7 +1079,13 @@ def modulate(x, shift, scale):
     M = x.shape[0] * x.shape[1]
     N = x.shape[2]
     seq_size = x.shape[1]
-    BLOCK_SIZE = min(1024, triton.next_power_of_2(N))
+    BLOCK_SIZE = triton.next_power_of_2(N)

+    # baseline: pure-Paddle reference path, used when INFERENCE_OPTIMIZE_TRITON is unset.
+    if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
+        norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon)
+        norm_hidden_states = norm_hidden_states * (1 + scale[:, None]) + shift[:, None]
+        return norm_hidden_states
+
     op_name = "triton_adaptive_layer_norm"
     op_name += get_dtype_str(x.dtype)
@@ -1096,9 +1109,9 @@ def modulate(x, shift, scale):
         y,
         y,
         y,
-        M,
+        -1,
         N,
-        seq_size,
+        -1,
         epsilon,
         BLOCK_SIZE=BLOCK_SIZE,
         weight_attr=weight_attr,
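
The baseline branches above double as reference implementations, so the Triton kernels can be sanity-checked against pure Paddle. A hypothetical check along these lines (the shapes, dtype, and the idea of toggling the variable between calls are assumptions, and it needs the GPU plus triton setup described in the README below):

```python
# Sketch: compare the pure-Paddle baseline against the Triton fast path by
# toggling INFERENCE_OPTIMIZE_TRITON (the ops read it on every call).
import os

import paddle
import paddlemix

x = paddle.randn([2, 256, 1152], dtype="float16")  # [batch, seq_len, dim]
mha_out = paddle.randn([2, 256, 1152], dtype="float16")
gate = paddle.randn([2, 1152], dtype="float16")
scale = paddle.randn([2, 1152], dtype="float16")
shift = paddle.randn([2, 1152], dtype="float16")

os.environ.pop("INFERENCE_OPTIMIZE_TRITON", None)  # baseline path
resi_ref, out_ref = paddlemix.triton_ops.fused_adaLN_scale_residual(
    x, mha_out, gate, scale, shift, epsilon=1e-05
)

os.environ["INFERENCE_OPTIMIZE_TRITON"] = "1"  # Triton kernel path
resi_tri, out_tri = paddlemix.triton_ops.fused_adaLN_scale_residual(
    x, mha_out, gate, scale, shift, epsilon=1e-05
)

print(float((out_ref - out_tri).abs().max()))  # expect ~0 up to fp16 rounding
```
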
@@ -215,6 +215,29 @@ image = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator
image.save("result_DiT_golden_retriever.png")
```

### 2.3 High-Performance Inference with Paddle Inference

- Paddle Inference provides a high-performance inference implementation for the DiT model, improving inference performance by more than 80%. The inference steps are as follows:
```shell
# Install the develop (nightly) build of PaddlePaddle
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
# Install Triton
python -m pip install triton

```

A review comment was left on the `pip install triton` line:

> Won't this also install torch and a whole chain of dependencies?

The author replied:

> The documentation has been revised into a more comprehensive introduction that explains how to adapt Triton for Paddle. Thanks for the review!
One-command inference:
```shell
python ppdiffusers/examples/inference/class_conditional_image_generation-dit.py --inference_optimize 1
```
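
Note that the ops in `paddlemix/triton_ops/triton_ops.py` fall back to their pure-Paddle baselines whenever `INFERENCE_OPTIMIZE_TRITON` is unset, so, assuming the script does not set the variable internally, the Triton fast path is enabled explicitly:

```shell
# Assumed invocation: export the flag read by paddlemix/triton_ops/triton_ops.py,
# then run the same one-command script.
export INFERENCE_OPTIMIZE_TRITON=1
python ppdiffusers/examples/inference/class_conditional_image_generation-dit.py --inference_optimize 1
```
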

- Performance measured on an NVIDIA A100-SXM4-40GB:

| Paddle Inference | TensorRT-LLM | Paddle dygraph |
| ---------------- | ------------ | -------------- |
| 219 ms           | 242 ms       | 1200 ms        |

## Citation
```
@@ -59,6 +59,24 @@ def parse_args():
 pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
 set_seed(42)

+if args.inference_optimize:
+    # optimize the transformer using paddle.incubate.jit.inference
+    pipe.transformer = paddle.incubate.jit.inference(
+        pipe.transformer,
+        enable_new_ir=True,
+        save_model_dir="./tmp/dit",
+        cache_static_model=True,
+        exp_enable_use_cutlass=True,
+        delete_pass_lists=["add_norm_fuse_pass"],
+    )
+    pipe.vae.decode = paddle.incubate.jit.inference(
+        pipe.vae.decode,
+        enable_new_ir=True,
+        save_model_dir="./tmp/dit/vae",
+        cache_static_model=True,
+        exp_enable_use_cutlass=True,
+    )
+
 words = ["golden retriever"]  # class_ids [207]
 class_ids = pipe.get_label_ids(words)
 image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
@@ -71,15 +89,15 @@ def parse_args():

 repeat_times = 5

-paddle.device.synchronize()
-starttime = datetime.datetime.now()
 for i in range(repeat_times):
+    paddle.device.synchronize()
+    starttime = datetime.datetime.now()
     image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
-paddle.device.synchronize()
-endtime = datetime.datetime.now()
+    paddle.device.synchronize()
+    endtime = datetime.datetime.now()

-duringtime = endtime - starttime
-time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-print("The ave end to end time : ", time_ms / repeat_times, "ms")
+    duringtime = endtime - starttime
+    time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+    print("This end to end time : ", time_ms, "ms")

 image.save("class_conditional_image_generation-dit-result.png")
42 changes: 19 additions & 23 deletions ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import math
-import os

 import paddle
 import paddle.nn.functional as F
@@ -84,6 +83,10 @@ def forward(self, hidden_states, timesteps, class_labels):
         emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
         common_emb = emb.cast(hidden_states.dtype)

+        last_ffn_output = None
+        last_hidden_states = None
+        last_gate_mlp = None
+
         for i in range(self.num_layers):
             emb = self.fcs0[i](common_emb)
             emb = F.silu(emb)
@@ -94,44 +97,37 @@ def forward(self, hidden_states, timesteps, class_labels):
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
             import paddlemix

-            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            if last_ffn_output is None:
                 norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
                     hidden_states, scale_msa, shift_msa, epsilon=1e-06
                 )
             else:
-                norm_hidden_states = self.norm(
-                    hidden_states,
+                hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                    last_hidden_states, last_ffn_output, last_gate_mlp, scale_msa, shift_msa, epsilon=1e-06
                 )
-                norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]

             q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
             k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
             v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])

             norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
-            norm_hidden_states = norm_hidden_states.reshape(
-                [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
-            )
+            norm_hidden_states = norm_hidden_states.reshape([0, 0, self.dim])
             norm_hidden_states = self.out_proj[i](norm_hidden_states)
-            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
-                hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
-                    hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
-                )
-            else:
-                hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape(
-                    [norm_hidden_states.shape[0], 1, self.dim]
-                )
-                norm_hidden_states = self.norm1(
-                    hidden_states,
-                )
-                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
+            )

             norm_hidden_states = self.ffn1[i](norm_hidden_states)
             norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
             norm_hidden_states = self.ffn2[i](norm_hidden_states)

-            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
-                [norm_hidden_states.shape[0], 1, self.dim]
-            )
+            last_ffn_output = norm_hidden_states
+            last_hidden_states = hidden_states
+            last_gate_mlp = gate_mlp
+
+        hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
+            [norm_hidden_states.shape[0], 1, self.dim]
+        )

         return hidden_states
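
The rewritten loop is a deferred-fusion pattern: the residual add that used to close layer `i` and the layer norm that opened layer `i+1` are folded into a single `fused_adaLN_scale_residual` call at the top of iteration `i+1` (carried via `last_hidden_states` / `last_ffn_output` / `last_gate_mlp`), and only the final residual add runs after the loop. A toy, self-contained sketch of the pattern (all names and shapes are illustrative, not the actual module):

```python
import paddle
import paddle.nn.functional as F

def fused_residual_layer_norm(x, ffn_out, gate, dim):
    # Stand-in for fused_adaLN_scale_residual: residual add + layer norm in one step.
    x = x + ffn_out * gate
    return x, F.layer_norm(x, [dim])

dim, num_layers = 8, 3
x = paddle.randn([2, 4, dim])
gate = paddle.full([2, 1, dim], 0.5)
ffns = [paddle.nn.Linear(dim, dim) for _ in range(num_layers)]

pending = None  # (x, ffn_out, gate) carried over from the previous layer
for ffn in ffns:
    if pending is None:
        y = F.layer_norm(x, [dim])  # first layer: nothing to fuse yet
    else:
        x, y = fused_residual_layer_norm(*pending, dim)
    ffn_out = ffn(y)
    pending = (x, ffn_out, gate)

x = x + ffn_out * gate  # final residual add, hoisted out of the loop
print(x.shape)  # [2, 4, 8]
```
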
10 changes: 2 additions & 8 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
@@ -225,13 +225,6 @@ def __init__(
             self.simplified_facebookdit = SimplifiedFacebookDIT(
                 num_layers, inner_dim, num_attention_heads, attention_head_dim
             )
-            self.simplified_facebookdit = paddle.incubate.jit.inference(
-                self.simplified_facebookdit,
-                enable_new_ir=True,
-                cache_static_model=False,
-                exp_enable_use_cutlass=True,
-                delete_pass_lists=["add_norm_fuse_pass"],
-            )

         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
@@ -498,7 +491,8 @@ def custom_forward(*inputs):
         hidden_states = hidden_states.reshape(
             shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
         )
-        hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
+        # hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
+        hidden_states = hidden_states.transpose([0, 5, 1, 3, 2, 4])
         output = hidden_states.reshape(
             shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
         )
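
The new `transpose` reproduces the removed `einsum` exactly: in `nhwpqc->nchpwq`, output axis `k` takes input axis `perm[k]`, giving `perm = [0, 5, 1, 3, 2, 4]`. A quick equivalence check (shapes are illustrative):

```python
import paddle

x = paddle.randn([2, 3, 4, 5, 6, 7])  # axes: n h w p q c
a = paddle.einsum("nhwpqc->nchpwq", x)
b = x.transpose([0, 5, 1, 3, 2, 4])   # axes: n c h p w q
print(paddle.allclose(a, b).item())   # True
```
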
2 changes: 1 addition & 1 deletion ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -463,7 +463,7 @@ def scaled_dot_product_attention_(
             pre_cache_length=0,
         ).transpose([0, 2, 1, 3])
     elif attention_op == "flash":
-        with requires_grad_and_without_random(query, key, value):
+        with requires_grad_and_without_random(query, key, value, stop_gradient=False):
             output = paddle.nn.functional.scaled_dot_product_attention(
                 query,
                 key,
17 changes: 13 additions & 4 deletions ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -191,9 +191,12 @@ def __call__(
                     ]
                 )
                 # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input, timestep=timesteps, class_labels=class_labels_input
-                ).sample
+                noise_pred_out = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
+                if paddle.incubate.jit.is_inference_mode(self.transformer):
+                    # self.transformer runs under Paddle Inference and returns the Tensor directly.
+                    noise_pred = noise_pred_out
+                else:
+                    noise_pred = noise_pred_out.sample

                 # perform guidance
                 if guidance_scale > 1:
@@ -222,7 +225,13 @@ def __call__(
             latents = latent_model_input

         latents = 1 / self.vae.config.scaling_factor * latents
-        samples = self.vae.decode(latents).sample
+
+        samples_out = self.vae.decode(latents)
+        if paddle.incubate.jit.is_inference_mode(self.vae.decode):
+            # self.vae.decode runs under Paddle Inference and returns the Tensor directly.
+            samples = samples_out
+        else:
+            samples = samples_out.sample

         samples = (samples / 2 + 0.5).clip(0, 1)