From e42a8161d725af36eaa0e248496c76b085b689c8 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 03:12:20 +0000
Subject: [PATCH 01/13] update DIT doc

---
 .../DiT/README.md | 41 +++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index abb3347cc..568a78010 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -216,6 +216,47 @@ image.save("result_DiT_golden_retriever.png")
 ```
 
+### 2.3 Paddle Inference 高性能推理
+
+
+- Paddle Inference加速DIT推理
+
+
+```shell
+# 安装develop版本的paddle
+python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
+# 安装 triton
+python -m pip install triton
+```
+
+执行高性能推理的命令是:
+
+```shell
+cd ppdiffusers/examples/inference
+python class_conditional_image_generation-dit.py --inference_optimize 1
+```
+
+- 在 NVIDIA A100-SXM4-40GB 上测试的性能如下:
+
+| Paddle Inference| TensorRT-LLM |
+| --------------- | ------------ |
+| 219 ms          | 242 ms |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 
 ## 引用
 ```
 @article{Peebles2022DiT,

From 591c7dd2cf4a4e58e633b4ec30514c5850d410f4 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 03:21:59 +0000
Subject: [PATCH 02/13] update DIT doc

---
 .../class_conditional_image_generation/DiT/README.md | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index 568a78010..447dc2726 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -215,22 +215,17 @@ image = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator
 image.save("result_DiT_golden_retriever.png")
 ```
-
-### 2.3 Paddle Inference 高性能推理
-
-
-- Paddle Inference加速DIT推理
-
-
+### 2.3 Paddle Inference 高性能推理
+- Paddle Inference提供DIT模型高性能推理实现,推理性能提升80%+
+推理步骤如下:
 ```shell
 # 安装develop版本的paddle
 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
 # 安装 triton
 python -m pip install triton
 ```
-
-执行高性能推理的命令是:
-
+一键推理指令:
 ```shell
 cd ppdiffusers/examples/inference
 python class_conditional_image_generation-dit.py --inference_optimize 1
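Patch 03 that follows wraps `pipe.transformer` in `paddle.incubate.jit.inference`, which traces the layer to a static graph and executes it through the Paddle Inference engine. One side effect drives the pipeline change in the same patch: the wrapped callable returns bare tensors instead of the usual output dataclass, so call sites have to branch on `paddle.incubate.jit.is_inference_mode`. A minimal sketch of that call-site pattern (the helper name is ours; the branch bodies mirror the patch):

```python
import paddle


def predict_noise(transformer, latent_model_input, timesteps, class_labels_input):
    out = transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
    if paddle.incubate.jit.is_inference_mode(transformer):
        # Static-graph (Paddle Inference) path: the output is already a bare tensor.
        return out
    # Dynamic-graph path: unwrap the .sample field of the output dataclass.
    return out.sample
```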
From 678db6ea1eb75633bb71f759e5908fdfc870c205 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 03:55:53 +0000
Subject: [PATCH 03/13] update dit

---
 .../class_conditional_image_generation/DiT/README.md   |  6 +++---
 .../class_conditional_image_generation-dit.py          | 11 +++++++++++
 ppdiffusers/ppdiffusers/models/transformer_2d.py       | 10 ++--------
 ppdiffusers/ppdiffusers/patches/paddle_patch.py        |  2 +-
 ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py  |  9 ++++++---
 5 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index 447dc2726..296dd0599 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -233,9 +233,9 @@ python class_conditional_image_generation-dit.py --inference_optimize 1
 
 - 在 NVIDIA A100-SXM4-40GB 上测试的性能如下:
 
-| Paddle Inference| TensorRT-LLM |
-| --------------- | ------------ |
-| 219 ms          | 242 ms |
+| Paddle Inference| TensorRT-LLM | Paddle  |
+| --------------- | ------------ | ------- |
+| 219 ms          | 242 ms       | 1200 ms |
 
 
 
diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
index 84f6c7d8f..2b02c2ad5 100644
--- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
+++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
@@ -59,6 +59,17 @@ def parse_args():
 pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
 set_seed(42)
 
+if args.inference_optimize:
+    # optimize the transformer using paddle.incubate.jit.inference
+    pipe.transformer = paddle.incubate.jit.inference(
+        pipe.transformer,
+        enable_new_ir=True,
+        save_model_dir="./tmp/dit",
+        cache_static_model=True,
+        exp_enable_use_cutlass=True,
+        delete_pass_lists=["add_norm_fuse_pass"],
+    )
+
 words = ["golden retriever"]  # class_ids [207]
 class_ids = pipe.get_label_ids(words)
 image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py
index cfd638460..3a4e7af01 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_2d.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py
@@ -225,13 +225,6 @@ def __init__(
             self.simplified_facebookdit = SimplifiedFacebookDIT(
                 num_layers, inner_dim, num_attention_heads, attention_head_dim
             )
-            self.simplified_facebookdit = paddle.incubate.jit.inference(
-                self.simplified_facebookdit,
-                enable_new_ir=True,
-                cache_static_model=False,
-                exp_enable_use_cutlass=True,
-                delete_pass_lists=["add_norm_fuse_pass"],
-            )
 
         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
@@ -498,7 +491,8 @@ def custom_forward(*inputs):
             hidden_states = hidden_states.reshape(
                 shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
             )
-            hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
+            # hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
+            hidden_states = hidden_states.transpose([0, 5, 1, 3, 2, 4])
             output = hidden_states.reshape(
                 shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
             )
diff --git a/ppdiffusers/ppdiffusers/patches/paddle_patch.py b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
index 8a73e4315..6845ad632 100644
--- a/ppdiffusers/ppdiffusers/patches/paddle_patch.py
+++ b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -463,7 +463,7 @@ def scaled_dot_product_attention_(
             pre_cache_length=0,
         ).transpose([0, 2, 1, 3])
     elif attention_op == "flash":
-        with requires_grad_and_without_random(query, key, value):
+        with requires_grad_and_without_random(query, key, value, stop_gradient=False):
             output = paddle.nn.functional.scaled_dot_product_attention(
                 query,
                 key,
diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
index 8654bbc2e..9586df292 100644
--- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
+++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -191,9 +191,12 @@ def __call__(
                 ]
             )
             # predict noise model_output
-            noise_pred = self.transformer(
-                latent_model_input, timestep=timesteps, class_labels=class_labels_input
-            ).sample
+            noise_pred_out = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
+            if paddle.incubate.jit.is_inference_mode(self.transformer):
+                # self.transformer run in paddle inference.
+                noise_pred = noise_pred_out
+            else:
+                noise_pred = noise_pred_out.sample
 
             # perform guidance
             if guidance_scale > 1:

From d4aadb8d4acfcab54f5effd32189ca0a45bd637e Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 05:27:12 +0000
Subject: [PATCH 04/13] update vae d2s

---
 .../class_conditional_image_generation-dit.py |  8 ++++++++
 .../ppdiffusers/pipelines/dit/pipeline_dit.py | 19 ++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
index 2b02c2ad5..f9d0db670 100644
--- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
+++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
@@ -69,6 +69,14 @@ def parse_args():
         exp_enable_use_cutlass=True,
         delete_pass_lists=["add_norm_fuse_pass"],
     )
+    pipe.vae.decode = paddle.incubate.jit.inference(
+        pipe.vae.decode,
+        enable_new_ir=True,
+        save_model_dir="./tmp/dit/vae",
+        cache_static_model=True,
+        exp_enable_use_cutlass=True,
+        delete_pass_lists=["add_norm_fuse_pass"],
+    )
 
 words = ["golden retriever"]  # class_ids [207]
 class_ids = pipe.get_label_ids(words)
diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
index 9586df292..26cd33284 100644
--- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
+++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -225,7 +225,24 @@ def __call__(
             latents = latent_model_input
 
         latents = 1 / self.vae.config.scaling_factor * latents
-        samples = self.vae.decode(latents).sample
+
+        import datetime
+
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
+
+        samples_out = self.vae.decode(latents)
+        if paddle.incubate.jit.is_inference_mode(self.vae.decode):
+            # self.vae.decode run in paddle inference.
+            samples = samples_out
+        else:
+            samples = samples_out.sample
+
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
+        duringtime = endtime - starttime
+        time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        print("The VAE decode time : ", time_ms, "ms")
 
         samples = (samples / 2 + 0.5).clip(0, 1)

From 0317bfbc91b4aba7b81eb63c0e4a42294adc5057 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 05:29:41 +0000
Subject: [PATCH 05/13] update vae d2s

---
 .../ppdiffusers/pipelines/dit/pipeline_dit.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
index 26cd33284..f61ccf07d 100644
--- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
+++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -226,10 +226,10 @@ def __call__(
 
         latents = 1 / self.vae.config.scaling_factor * latents
 
-        import datetime
+        # import datetime
 
-        paddle.device.synchronize()
-        starttime = datetime.datetime.now()
+        # paddle.device.synchronize()
+        # starttime = datetime.datetime.now()
 
         samples_out = self.vae.decode(latents)
         if paddle.incubate.jit.is_inference_mode(self.vae.decode):
@@ -238,11 +238,11 @@ def __call__(
         else:
             samples = samples_out.sample
 
-        paddle.device.synchronize()
-        endtime = datetime.datetime.now()
-        duringtime = endtime - starttime
-        time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-        print("The VAE decode time : ", time_ms, "ms")
+        # paddle.device.synchronize()
+        # endtime = datetime.datetime.now()
+        # duringtime = endtime - starttime
+        # time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        # print("The VAE decode time : ", time_ms, "ms")
 
         samples = (samples / 2 + 0.5).clip(0, 1)
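Patches 04 and 05 add, then comment out, a timing probe around the VAE decode. The probe follows the standard recipe for timing asynchronous GPU work, which is worth keeping as a reusable sketch (the helper name is ours; the synchronize-and-`datetime` pattern is exactly the one in the diff):

```python
import datetime

import paddle


def time_call_ms(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed milliseconds).

    GPU kernels launch asynchronously, so paddle.device.synchronize() is
    needed on both sides; without it the timer would only measure how long
    it takes to enqueue the work, not to finish it.
    """
    paddle.device.synchronize()
    start = datetime.datetime.now()
    result = fn(*args, **kwargs)
    paddle.device.synchronize()
    elapsed = datetime.datetime.now() - start
    return result, elapsed.seconds * 1000 + elapsed.microseconds / 1000.0
```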
From 548d3cc7fb702a5ae8451f0a6c14cbecf511a69e Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 07:19:45 +0000
Subject: [PATCH 06/13] update DIT inference_optimize

---
 paddlemix/triton_ops/__init__.py              |   4 +
 paddlemix/triton_ops/triton_ops.py            | 312 +++++++++++++++++-
 .../class_conditional_image_generation-dit.py |  15 +-
 .../models/simplified_facebook_dit.py         |  38 +--
 4 files changed, 336 insertions(+), 33 deletions(-)

diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py
index 76db91ab2..f10d9daaf 100644
--- a/paddlemix/triton_ops/__init__.py
+++ b/paddlemix/triton_ops/__init__.py
@@ -21,6 +21,8 @@
     fused_rotary_emb,
     paddle_use_triton,
     rms_norm,
+    split_concat,
+    triton_split,
     weight_only_int8,
 )
 from .triton_utils import (
@@ -39,6 +41,8 @@
     "rms_norm",
     "get_dtype_str",
     "fused_rotary_emb",
+    "split_concat",
+    "triton_split",
 ]
 except:
     pass
diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 3ade229c3..070d8302b 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -839,6 +839,13 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
     seq_size = x.shape[1]
     N_npo2 = triton.next_power_of_2(N)
 
+    # baseline.
+    if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
+        resi_out_paddle = mha_out * gate_msa.unsqueeze(axis=1) + x
+        norm_hidden_states = paddle.nn.functional.layer_norm(resi_out_paddle, [N], weight, bias, epsilon)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return resi_out_paddle, norm_hidden_states
+
     op_name = "triton_fused_adaLN_scale_residual"
     op_name += get_dtype_str(x.dtype)
     op_name += f"_{N_npo2}_{weight_attr}_{bias_attr}"
@@ -865,9 +872,9 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
         shift_mlp,
         resi_out,
         adaLN_out,
-        M,
+        -1,
         N,
-        seq_size,
+        -1,
         epsilon,
         N_npo2=N_npo2,
         weight_attr=weight_attr,
@@ -1072,7 +1079,13 @@ def modulate(x, shift, scale):
     M = x.shape[0] * x.shape[1]
     N = x.shape[2]
     seq_size = x.shape[1]
-    BLOCK_SIZE = min(1024, triton.next_power_of_2(N))
+    BLOCK_SIZE = 2048  # min(1024, triton.next_power_of_2(N))
+
+    # baseline.
+    if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
+        norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon)
+        norm_hidden_states = norm_hidden_states * (1 + scale[:, None]) + shift[:, None]
+        return norm_hidden_states
 
     op_name = "triton_adaptive_layer_norm"
     op_name += get_dtype_str(x.dtype)
@@ -1096,9 +1109,9 @@ def modulate(x, shift, scale):
         y,
         y,
         y,
-        M,
+        -1,
         N,
-        seq_size,
+        -1,
         epsilon,
         BLOCK_SIZE=BLOCK_SIZE,
         weight_attr=weight_attr,
@@ -1567,3 +1580,292 @@ def fused_rotary_emb(
         outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
     )
     return q_out, k_out, v_out
+
+
+########################### split concat ###############################
+split_concat_template = (
+    """
+std::vector<paddle::Tensor> ${op_name}_func(
+    const paddle::Tensor &x,
+    const paddle::Tensor &y) {
+
+  int batch = x.dims()[0];
+
+  int seq_qkv = x.dims()[1];
+  int seq_eqkv = y.dims()[1];
+  int output_hidden = x.dims()[2] / 3;
+
+
+  auto qkv = get_tensor_ptr(x);
+  auto eqkv = get_tensor_ptr(y);
+
+
+  auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+
+  auto out0 = get_tensor_ptr(out0_tensor);
+  auto out1 = get_tensor_ptr(out1_tensor);
+  auto out2 = get_tensor_ptr(out2_tensor);
+
+
+  auto run_stream = out0_tensor.stream();
+
+"""
+    + tune_and_invoke_part
+    + """
+    return {out0_tensor, out1_tensor, out2_tensor};
+}
+
+std::vector<std::vector<int64_t>> ${op_name}_InferShape(
+        const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
+
+  std::vector<int64_t> out_shape = {A_shape[0], A_shape[1]+B_shape[1], A_shape[2]/3};
+
+  return {out_shape, out_shape, out_shape};
+}
+
+std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
+    return {A_dtype, A_dtype, A_dtype};
+}
+
+PD_BUILD_OP(${op_name})
+    .Inputs({"x", "y"})
+    .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
+    .SetKernelFn(PD_KERNEL(${op_name}_func))
+    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
+"""
+)
+
+
+@paddle_use_triton(
+    custom_op_template=split_concat_template,
+    key=["1"],
+)
+def split_concat_kernel(
+    out0,
+    out1,
+    out2,
+    qkv,
+    eqkv,
+    batch,
+    seq_qkv,
+    seq_eqkv,
+    output_hidden,
+    BLOCK_SIZE: tl.constexpr,
+):
+    out_id = tl.program_id(axis=0)
+    batch = tl.program_id(axis=1)
+    out_row = tl.program_id(axis=2)
+    if out_row < seq_qkv:
+        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
+    else:
+        read_ptr = (
+            out_id * output_hidden
+            + (out_row - seq_qkv) * 3 * output_hidden
+            + batch * seq_eqkv * output_hidden * 3
+            + eqkv
+        )
+
+    read_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = read_offsets < output_hidden
+    read_data = tl.load(read_ptr + read_offsets, mask=mask)
+
+    real_output = out0
+    if out_id == 1:
+        real_output = out1
+    elif out_id == 2:
+        real_output = out2
+
+    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets
+
+    tl.store(write_ptr, read_data, mask=mask)
+
+
+def split_concat(x, y):
+    assert len(x.shape) == 3
+    assert len(y.shape) == 3
+
+    assert x.shape[0] == y.shape[0]
+    assert x.shape[2] == y.shape[2]
+
+    batch = x.shape[0]
+    seq_qkv = x.shape[1]
+    hidd_x = x.shape[2]
+    seq_eqkv = y.shape[1]
+    ouput_hidden = hidd_x // 3
+    BLOCK_SIZE = triton.next_power_of_2(ouput_hidden)
+    op_name = "split_concat"
+    op_name += get_dtype_str(x.dtype)
+    op_name += f"_{BLOCK_SIZE}"
+
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        grid = ("3", "batch", "seq_qkv + seq_eqkv")
+
+        split_concat_kernel[(op_name, grid)](
+            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048
+        )
+
+    if in_dynamic_or_pir_mode():
+        print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            x,
+            y,
+        )
+        return outs[0], outs[1], outs[2]
+    else:
+        print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+            "y": y,
+        }
+        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
+        )
+        return out0, out1, out2
+
+
+########################### triton split ###############################
+triton_split_template = (
+    """
+std::vector<paddle::Tensor> ${op_name}_func(
+    const paddle::Tensor &x,
+    const std::vector<int64_t> num_or_sections,
+    const int64_t axis) {
+
+  int output_batch = x.dims()[0];
+  int output_seq0 = num_or_sections[0];
+  int output_seq1 = num_or_sections[1];
+  int output_hidden = x.dims()[2];
+
+  auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
+  auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());
+
+  auto out0 = get_tensor_ptr(out0_tensor);
+  auto out1 = get_tensor_ptr(out1_tensor);
+
+  auto input = get_tensor_ptr(x);
+
+  auto run_stream = out0_tensor.stream();
+
+"""
+    + tune_and_invoke_part
+    + """
+    return {out0_tensor, out1_tensor};
+}
+
+std::vector<std::vector<int64_t>> ${op_name}_InferShape(
+        const std::vector<int64_t>& A_shape) {
+
+  std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
+  std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};
+
+  return {out_shape0, out_shape1};
+}
+
+std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
+    return {A_dtype, A_dtype};
+}
+
+PD_BUILD_OP(${op_name})
+    .Inputs({"x"})
+    .Outputs({"out0_tensor", "out1_tensor"})
+    .SetKernelFn(PD_KERNEL(${op_name}_func))
+    .Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
+    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
+"""
+)
+
+
+@paddle_use_triton(
+    custom_op_template=triton_split_template,
+    key=["1"],
+)
+def triton_split_kernel(
+    out0,
+    out1,
+    input,
+    output_seq0,
+    output_seq1,
+    output_batch,
+    output_hidden,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch = tl.program_id(axis=0)
+    out_row = tl.program_id(axis=1)
+    read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input
+
+    read_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = read_offsets < output_hidden
+    read_data = tl.load(read_ptr + read_offsets, mask=mask)
+
+    if out_row < output_seq0:
+        write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
+    else:
+        write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets
+
+    tl.store(write_ptr, read_data, mask=mask)
+
+
+def triton_split(x, num_or_sections=[-1, -1], axis=1):
+    assert len(x.shape) == 3
+    output_batch = x.shape[0]
+    output_seq0 = num_or_sections[0]
+    output_seq1 = num_or_sections[1]
+    output_hidden = x.shape[2]
+
+    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
+    op_name = "triton_split"
+    op_name += get_dtype_str(x.dtype)
+    op_name += f"_{BLOCK_SIZE}"
+
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype)
+        out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype)
+        grid = ("output_batch", "output_seq0+output_seq1")
+
+        triton_split_kernel[(op_name, grid)](
+            out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=2048
+        )
+
+    if in_dynamic_or_pir_mode():
+        print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            x,
+            num_or_sections,
+            axis,
+        )
+        return outs[0], outs[1]
+    else:
+        print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+        }
+        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            attrs={
+                "num_or_sections": num_or_sections,
+                "axis": axis,
+            },
+            outputs={"out0_tensor": out0, "out1_tensor": out1},
+        )
+        return out0, out1
diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
index f9d0db670..04c9db51e 100644
--- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
+++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
@@ -75,7 +75,6 @@ def parse_args():
         save_model_dir="./tmp/dit/vae",
         cache_static_model=True,
         exp_enable_use_cutlass=True,
-        delete_pass_lists=["add_norm_fuse_pass"],
     )
 
 words = ["golden retriever"]  # class_ids [207]
@@ -90,15 +89,15 @@
 
     repeat_times = 5
 
-    paddle.device.synchronize()
-    starttime = datetime.datetime.now()
     for i in range(repeat_times):
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
         image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
-    paddle.device.synchronize()
-    endtime = datetime.datetime.now()
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
 
-    duringtime = endtime - starttime
-    time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-    print("The ave end to end time : ", time_ms / repeat_times, "ms")
+        duringtime = endtime - starttime
+        time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        print("The this end to end time : ", time_ms, "ms")
 
 image.save("class_conditional_image_generation-dit-result.png")
diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
index 2a1fde485..cc5e4e603 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import os
 
 import paddle
 import paddle.nn.functional as F
@@ -84,6 +83,10 @@ def forward(self, hidden_states, timesteps, class_labels):
         emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
         common_emb = emb.cast(hidden_states.dtype)
 
+        last_ffn_output = None
+        last_hidden_states = None
+        last_gate_mlp = None
+
         for i in range(self.num_layers):
             emb = self.fcs0[i](common_emb)
             emb = F.silu(emb)
@@ -94,15 +97,14 @@ def forward(self, hidden_states, timesteps, class_labels):
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
             import paddlemix
 
-            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            if last_ffn_output is None:
                 norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
                     hidden_states, scale_msa, shift_msa, epsilon=1e-06
                 )
             else:
-                norm_hidden_states = self.norm(
-                    hidden_states,
+                hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                    last_hidden_states, last_ffn_output, last_gate_mlp, scale_msa, shift_msa, epsilon=1e-06
                 )
-                norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
 
             q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
             k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
@@ -113,25 +115,21 @@ def forward(self, hidden_states, timesteps, class_labels):
             v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
             norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
             norm_hidden_states = norm_hidden_states.reshape(
                 [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
             )
             norm_hidden_states = self.out_proj[i](norm_hidden_states)
-            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
-                hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
-                    hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
-                )
-            else:
-                hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape(
-                    [norm_hidden_states.shape[0], 1, self.dim]
-                )
-                norm_hidden_states = self.norm1(
-                    hidden_states,
-                )
-                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
+            )
 
             norm_hidden_states = self.ffn1[i](norm_hidden_states)
             norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
             norm_hidden_states = self.ffn2[i](norm_hidden_states)
-            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
-                [norm_hidden_states.shape[0], 1, self.dim]
-            )
+            last_ffn_output = norm_hidden_states
+            last_hidden_states = hidden_states
+            last_gate_mlp = gate_mlp
+
+        hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
+            [norm_hidden_states.shape[0], 1, self.dim]
+        )
 
         return hidden_states
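Patch 06 gates its Triton kernels behind the `INFERENCE_OPTIMIZE_TRITON` environment variable: when it is unset, `fused_adaLN_scale_residual` falls back to the eager Paddle baseline added in the same hunk. That makes a quick numerical cross-check easy to sketch. Shapes here are illustrative (DiT-XL uses a hidden width of 1152), the call signature is the one used by `simplified_facebook_dit.py` above, and fp16 kernels should only be compared under a loose tolerance:

```python
import os

import paddle
import paddlemix

batch, seq, dim = 2, 256, 1152
x = paddle.randn([batch, seq, dim]).cast("float16")        # residual branch
mha_out = paddle.randn([batch, seq, dim]).cast("float16")  # attention output
gate, scale, shift = (paddle.randn([batch, dim]).cast("float16") for _ in range(3))

# Eager Paddle baseline path (env var unset).
os.environ.pop("INFERENCE_OPTIMIZE_TRITON", None)
resi_ref, norm_ref = paddlemix.triton_ops.fused_adaLN_scale_residual(
    x, mha_out, gate, scale, shift, epsilon=1e-06
)

# Fused Triton kernel path (env var set).
os.environ["INFERENCE_OPTIMIZE_TRITON"] = "1"
resi_out, norm_out = paddlemix.triton_ops.fused_adaLN_scale_residual(
    x, mha_out, gate, scale, shift, epsilon=1e-06
)

print(paddle.allclose(resi_out.cast("float32"), resi_ref.cast("float32"), atol=1e-2))
print(paddle.allclose(norm_out.cast("float32"), norm_ref.cast("float32"), atol=1e-2))
```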
From 4e4f375a38ccf8b651bb1b0d6fb7f8c472f62c2c Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 07:50:21 +0000
Subject: [PATCH 07/13] update DIT inference_optimize

---
 .../class_conditional_image_generation/DiT/README.md | 12 ------------
 .../ppdiffusers/models/simplified_facebook_dit.py    |  4 +---
 2 files changed, 1 insertion(+), 15 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index 296dd0599..a7e8342e7 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -240,18 +240,6 @@ python class_conditional_image_generation-dit.py --inference_optimize 1
 
 
 
-
-
-
-
-
-
-
-
-
-
-
-
 ## 引用
 ```
 @article{Peebles2022DiT,
diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
index cc5e4e603..6014e4d54 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
@@ -111,9 +111,7 @@ def forward(self, hidden_states, timesteps, class_labels):
             v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
             norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
-            norm_hidden_states = norm_hidden_states.reshape(
-                [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
-            )
+            norm_hidden_states = norm_hidden_states.reshape([0, 0, self.dim])
             norm_hidden_states = self.out_proj[i](norm_hidden_states)
 
             hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(

From db1373225a75e9f263d74a3c2ed90205778ec5c6 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 07:54:48 +0000
Subject: [PATCH 08/13] update DIT inference_optimize

---
 paddlemix/triton_ops/__init__.py              |   4 -
 paddlemix/triton_ops/triton_ops.py            | 289 ------------------
 .../ppdiffusers/pipelines/dit/pipeline_dit.py |  11 -
 3 files changed, 304 deletions(-)

diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py
index f10d9daaf..76db91ab2 100644
--- a/paddlemix/triton_ops/__init__.py
+++ b/paddlemix/triton_ops/__init__.py
@@ -21,8 +21,6 @@
     fused_rotary_emb,
     paddle_use_triton,
     rms_norm,
-    split_concat,
-    triton_split,
     weight_only_int8,
 )
 from .triton_utils import (
@@ -41,8 +39,6 @@
     "rms_norm",
     "get_dtype_str",
     "fused_rotary_emb",
-    "split_concat",
-    "triton_split",
 ]
 except:
     pass
diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 070d8302b..7561f6bcb 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1580,292 +1580,3 @@ def fused_rotary_emb(
         outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
     )
     return q_out, k_out, v_out
-
-
-########################### split concat ###############################
-split_concat_template = (
-    """
-std::vector<paddle::Tensor> ${op_name}_func(
-    const paddle::Tensor &x,
-    const paddle::Tensor &y) {
-
-  int batch = x.dims()[0];
-
-  int seq_qkv = x.dims()[1];
-  int seq_eqkv = y.dims()[1];
-  int output_hidden = x.dims()[2] / 3;
-
-
-  auto qkv = get_tensor_ptr(x);
-  auto eqkv = get_tensor_ptr(y);
-
-
-  auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
-  auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
-  auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
-
-  auto out0 = get_tensor_ptr(out0_tensor);
-  auto out1 = get_tensor_ptr(out1_tensor);
-  auto out2 = get_tensor_ptr(out2_tensor);
-
-
-  auto run_stream = out0_tensor.stream();
-
-"""
-    + tune_and_invoke_part
-    + """
-    return {out0_tensor, out1_tensor, out2_tensor};
-}
-
-std::vector<std::vector<int64_t>> ${op_name}_InferShape(
-        const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
-
-  std::vector<int64_t> out_shape = {A_shape[0], A_shape[1]+B_shape[1], A_shape[2]/3};
-
-  return {out_shape, out_shape, out_shape};
-}
-
-std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
-    return {A_dtype, A_dtype, A_dtype};
-}
-
-PD_BUILD_OP(${op_name})
-    .Inputs({"x", "y"})
-    .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
-    .SetKernelFn(PD_KERNEL(${op_name}_func))
-    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
-    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
-"""
-)
-
-
-@paddle_use_triton(
-    custom_op_template=split_concat_template,
-    key=["1"],
-)
-def split_concat_kernel(
-    out0,
-    out1,
-    out2,
-    qkv,
-    eqkv,
-    batch,
-    seq_qkv,
-    seq_eqkv,
-    output_hidden,
-    BLOCK_SIZE: tl.constexpr,
-):
-    out_id = tl.program_id(axis=0)
-    batch = tl.program_id(axis=1)
-    out_row = tl.program_id(axis=2)
-    if out_row < seq_qkv:
-        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
-    else:
-        read_ptr = (
-            out_id * output_hidden
-            + (out_row - seq_qkv) * 3 * output_hidden
-            + batch * seq_eqkv * output_hidden * 3
-            + eqkv
-        )
-
-    read_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = read_offsets < output_hidden
-    read_data = tl.load(read_ptr + read_offsets, mask=mask)
-
-    real_output = out0
-    if out_id == 1:
-        real_output = out1
-    elif out_id == 2:
-        real_output = out2
-
-    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets
-
-    tl.store(write_ptr, read_data, mask=mask)
-
-
-def split_concat(x, y):
-    assert len(x.shape) == 3
-    assert len(y.shape) == 3
-
-    assert x.shape[0] == y.shape[0]
-    assert x.shape[2] == y.shape[2]
-
-    batch = x.shape[0]
-    seq_qkv = x.shape[1]
-    hidd_x = x.shape[2]
-    seq_eqkv = y.shape[1]
-    ouput_hidden = hidd_x // 3
-    BLOCK_SIZE = triton.next_power_of_2(ouput_hidden)
-    op_name = "split_concat"
-    op_name += get_dtype_str(x.dtype)
-    op_name += f"_{BLOCK_SIZE}"
-
-    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
-        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
-        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
-        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
-        grid = ("3", "batch", "seq_qkv + seq_eqkv")
-
-        split_concat_kernel[(op_name, grid)](
-            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048
-        )
-
-    if in_dynamic_or_pir_mode():
-        print(f"== we are in dynamic mode, op_name: {op_name}")
-        outs = _C_ops._run_custom_op(
-            op_name,
-            x,
-            y,
-        )
-        return outs[0], outs[1], outs[2]
-    else:
-        print(f"== we are in dynamic to static mode, op_name: {op_name}")
-        helper = LayerHelper(op_name, **locals())
-        inputs = {
-            "x": x,
-            "y": y,
-        }
-        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
-        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
-        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
-
-        helper.append_op(
-            type=op_name,
-            inputs=inputs,
-            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
-        )
-        return out0, out1, out2
-
-
-########################### triton split ###############################
-triton_split_template = (
-    """
-std::vector<paddle::Tensor> ${op_name}_func(
-    const paddle::Tensor &x,
-    const std::vector<int64_t> num_or_sections,
-    const int64_t axis) {
-
-  int output_batch = x.dims()[0];
-  int output_seq0 = num_or_sections[0];
-  int output_seq1 = num_or_sections[1];
-  int output_hidden = x.dims()[2];
-
-  auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
-  auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());
-
-  auto out0 = get_tensor_ptr(out0_tensor);
-  auto out1 = get_tensor_ptr(out1_tensor);
-
-  auto input = get_tensor_ptr(x);
-
-  auto run_stream = out0_tensor.stream();
-
-"""
-    + tune_and_invoke_part
-    + """
-    return {out0_tensor, out1_tensor};
-}
-
-std::vector<std::vector<int64_t>> ${op_name}_InferShape(
-        const std::vector<int64_t>& A_shape) {
-
-  std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
-  std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};
-
-  return {out_shape0, out_shape1};
-}
-
-std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
-    return {A_dtype, A_dtype};
-}
-
-PD_BUILD_OP(${op_name})
-    .Inputs({"x"})
-    .Outputs({"out0_tensor", "out1_tensor"})
-    .SetKernelFn(PD_KERNEL(${op_name}_func))
-    .Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
-    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
-    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
-"""
-)
-
-
-@paddle_use_triton(
-    custom_op_template=triton_split_template,
-    key=["1"],
-)
-def triton_split_kernel(
-    out0,
-    out1,
-    input,
-    output_seq0,
-    output_seq1,
-    output_batch,
-    output_hidden,
-    BLOCK_SIZE: tl.constexpr,
-):
-    batch = tl.program_id(axis=0)
-    out_row = tl.program_id(axis=1)
-    read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input
-
-    read_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = read_offsets < output_hidden
-    read_data = tl.load(read_ptr + read_offsets, mask=mask)
-
-    if out_row < output_seq0:
-        write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
-    else:
-        write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets
-
-    tl.store(write_ptr, read_data, mask=mask)
-
-
-def triton_split(x, num_or_sections=[-1, -1], axis=1):
-    assert len(x.shape) == 3
-    output_batch = x.shape[0]
-    output_seq0 = num_or_sections[0]
-    output_seq1 = num_or_sections[1]
-    output_hidden = x.shape[2]
-
-    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
-    op_name = "triton_split"
-    op_name += get_dtype_str(x.dtype)
-    op_name += f"_{BLOCK_SIZE}"
-
-    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
-        out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype)
-        out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype)
-        grid = ("output_batch", "output_seq0+output_seq1")
-
-        triton_split_kernel[(op_name, grid)](
-            out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=2048
-        )
-
-    if in_dynamic_or_pir_mode():
-        print(f"== we are in dynamic mode, op_name: {op_name}")
-        outs = _C_ops._run_custom_op(
-            op_name,
-            x,
-            num_or_sections,
-            axis,
-        )
-        return outs[0], outs[1]
-    else:
-        print(f"== we are in dynamic to static mode, op_name: {op_name}")
-        helper = LayerHelper(op_name, **locals())
-        inputs = {
-            "x": x,
-        }
-        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
-        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
-
-        helper.append_op(
-            type=op_name,
-            inputs=inputs,
-            attrs={
-                "num_or_sections": num_or_sections,
-                "axis": axis,
-            },
-            outputs={"out0_tensor": out0, "out1_tensor": out1},
-        )
-        return out0, out1
diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
index f61ccf07d..732cbea3a 100644
--- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
+++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -226,11 +226,6 @@ def __call__(
 
         latents = 1 / self.vae.config.scaling_factor * latents
 
-        # import datetime
-
-        # paddle.device.synchronize()
-        # starttime = datetime.datetime.now()
-
         samples_out = self.vae.decode(latents)
         if paddle.incubate.jit.is_inference_mode(self.vae.decode):
             # self.vae.decode run in paddle inference.
@@ -238,12 +233,6 @@ def __call__(
         else:
             samples = samples_out.sample
 
-        # paddle.device.synchronize()
-        # endtime = datetime.datetime.now()
-        # duringtime = endtime - starttime
-        # time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-        # print("The VAE decode time : ", time_ms, "ms")
-
         samples = (samples / 2 + 0.5).clip(0, 1)
 
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

From b51d757d906a584bbc2ef1cdbacb58be06d37b31 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 07:56:55 +0000
Subject: [PATCH 09/13] update DIT inference_optimize

---
 .../examples/class_conditional_image_generation/DiT/README.md | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index a7e8342e7..fa6c080b1 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -227,8 +227,7 @@ python -m pip install triton
 ```
 一键推理指令:
 ```shell
-cd ppdiffusers/examples/inference
-python class_conditional_image_generation-dit.py --inference_optimize 1
+python ppdiffusers/examples/inference/class_conditional_image_generation-dit.py --inference_optimize 1
 ```
 
 - 在 NVIDIA A100-SXM4-40GB 上测试的性能如下:
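Patch 10 next swaps the hardcoded `BLOCK_SIZE = 2048` in `adaptive_layer_norm` for `triton.next_power_of_2(N)`. The kernel loads the whole hidden dimension in one block, so `BLOCK_SIZE` has to be a power of two at least as large as `N`; deriving it from `N` keeps small widths cheap and large widths correct, where a fixed constant would silently truncate anything above 2048. Lanes past `N` are already disabled by the `mask = offsets < N` pattern the kernels use. Illustrative values:

```python
import triton

for n in (768, 1152, 1153, 4096):
    print(n, "->", triton.next_power_of_2(n))
# 768 -> 1024, 1152 -> 2048, 1153 -> 2048, 4096 -> 4096
```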
From 5978a810c882ff40ba93255afd1ae167d37f86da Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 08:10:24 +0000
Subject: [PATCH 10/13] update DIT inference_optimize

---
 paddlemix/triton_ops/triton_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 7561f6bcb..7ac8ad2bc 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1079,7 +1079,7 @@ def modulate(x, shift, scale):
     M = x.shape[0] * x.shape[1]
     N = x.shape[2]
     seq_size = x.shape[1]
-    BLOCK_SIZE = 2048  # min(1024, triton.next_power_of_2(N))
+    BLOCK_SIZE = triton.next_power_of_2(N)
 
     # baseline.
     if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:

From eeb76a45bc9d02fe241475d76bc4869f905abb49 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 08:38:31 +0000
Subject: [PATCH 11/13] update DIT doc

---
 .../class_conditional_image_generation/DiT/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index fa6c080b1..a2fe5cf4c 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -232,9 +232,9 @@
 
 - 在 NVIDIA A100-SXM4-40GB 上测试的性能如下:
 
-| Paddle Inference| TensorRT-LLM | Paddle  |
-| --------------- | ------------ | ------- |
-| 219 ms          | 242 ms       | 1200 ms |
+| Paddle Inference| TensorRT-LLM | Paddle动态图 |
+| --------------- | ------------ | ------------ |
+| 219 ms          | 242 ms       | 1200 ms |

From adb169a6e259e14f5d7cd7193d8ac51455146279 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 11:42:40 +0000
Subject: [PATCH 12/13] update DIT

---
 .../class_conditional_image_generation/DiT/README.md |  8 ++++++--
 .../class_conditional_image_generation-dit.py        | 10 ++++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index a2fe5cf4c..efe80029b 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -222,9 +222,13 @@ image.save("result_DiT_golden_retriever.png")
 ```shell
 # 安装develop版本的paddle
 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
-# 安装 triton
+
+# 安装 triton并适配paddle
 python -m pip install triton
-```
+python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
+python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
+```
+
 一键推理指令:
 ```shell
 python ppdiffusers/examples/inference/class_conditional_image_generation-dit.py --inference_optimize 1
diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
index 04c9db51e..9ce17485d 100644
--- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
+++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
@@ -57,7 +57,7 @@
 dtype = paddle.float16
 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
 pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-set_seed(42)
+
 
 if args.inference_optimize:
     # optimize the transformer using paddle.incubate.jit.inference
@@ -76,7 +76,7 @@ def parse_args():
         cache_static_model=True,
         exp_enable_use_cutlass=True,
     )
-
+set_seed(42)
 words = ["golden retriever"]  # class_ids [207]
 class_ids = pipe.get_label_ids(words)
 image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
@@ -84,7 +84,8 @@
 
 if args.benchmark:
     # warmup
-    for i in range(5):
+    for i in range(3):
+        set_seed(42)
         image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
 
     repeat_times = 5
@@ -92,12 +93,13 @@ if args.benchmark:
     for i in range(repeat_times):
         paddle.device.synchronize()
         starttime = datetime.datetime.now()
+        set_seed(42)
         image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
         paddle.device.synchronize()
         endtime = datetime.datetime.now()
 
         duringtime = endtime - starttime
         time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-        print("The this end to end time : ", time_ms, "ms")
+        print("DIT end to end time : ", time_ms, "ms")
 
 image.save("class_conditional_image_generation-dit-result.png")

From 4ec4c56fd8fa6e18b1bb94a81fbcc1f6f12db709 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Thu, 29 Aug 2024 11:50:44 +0000
Subject: [PATCH 13/13] update DIT

---
 .../examples/class_conditional_image_generation/DiT/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
index efe80029b..0b84027cc 100644
--- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
+++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -218,7 +218,7 @@ image.save("result_DiT_golden_retriever.png")
 
 ### 2.3 Paddle Inference 高性能推理
 - Paddle Inference提供DIT模型高性能推理实现,推理性能提升80%+
-推理步骤如下:
+环境准备:
 ```shell
 # 安装develop版本的paddle
 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
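Taken together, the series leaves the example script in roughly the following shape. This is a condensed sketch of the optimized path (argument parsing and warmup trimmed; `paddle.seed` stands in for the script's `set_seed` helper, whose import sits outside the diff context):

```python
import datetime

import paddle

from ppdiffusers import DDIMScheduler, DiTPipeline

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=paddle.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Static-graph inference for the two heavy modules (patches 03 and 04).
pipe.transformer = paddle.incubate.jit.inference(
    pipe.transformer,
    enable_new_ir=True,
    save_model_dir="./tmp/dit",
    cache_static_model=True,
    exp_enable_use_cutlass=True,
    delete_pass_lists=["add_norm_fuse_pass"],
)
pipe.vae.decode = paddle.incubate.jit.inference(
    pipe.vae.decode,
    enable_new_ir=True,
    save_model_dir="./tmp/dit/vae",
    cache_static_model=True,
    exp_enable_use_cutlass=True,
)

class_ids = pipe.get_label_ids(["golden retriever"])  # [207]

# Benchmark loop as patch 12 leaves it: per-iteration timing with device
# syncs, and a reseed before every run so outputs stay identical.
for _ in range(5):
    paddle.device.synchronize()
    start = datetime.datetime.now()
    paddle.seed(42)
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
    paddle.device.synchronize()
    elapsed = datetime.datetime.now() - start
    print("DIT end to end time : ", elapsed.seconds * 1000 + elapsed.microseconds / 1000.0, "ms")

image.save("class_conditional_image_generation-dit-result.png")
```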