From a6631e7d7293c3725b30099ed988dd3d9e426205 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 19 Aug 2024 05:37:30 +0000 Subject: [PATCH 01/65] optimize SD3 --- ..._to_image_generation-stable_diffusion_3.py | 87 +++++++- ppdiffusers/ppdiffusers/models/attention.py | 22 ++- .../ppdiffusers/models/attention_processor.py | 32 ++- .../ppdiffusers/models/normalization.py | 18 +- .../ppdiffusers/models/simplified_sd3.py | 187 ++++++++++++++++++ .../ppdiffusers/patches/paddle_patch.py | 2 +- 6 files changed, 329 insertions(+), 19 deletions(-) create mode 100644 ppdiffusers/ppdiffusers/models/simplified_sd3.py diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index e536d1705..411cb9f25 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -11,13 +11,96 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + +os.environ["FLAGS_use_cuda_managed_memory"] = "true" +import argparse +import datetime import paddle + from ppdiffusers import StableDiffusion3Pipeline + pipe = StableDiffusion3Pipeline.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", paddle_dtype=paddle.float16 + "/root/.cache/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/stable-diffusion-3-medium-diffusers", + paddle_dtype=paddle.float16, + from_hf_hub=True, + from_diffusers=True, ) generator = paddle.Generator().manual_seed(42) prompt = "A cat holding a sign that says hello world" -image = pipe(prompt, generator=generator).images[0] + + +def parse_args(): + parser = argparse.ArgumentParser( + description=" Use PaddleMIX to accelerate the Stable Diffusion3 image generation model." 
+ ) + parser.add_argument( + "--benchmark", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + help="if benchmark is set to True, measure inference performance", + ) + parser.add_argument( + "--inference_optimize", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + help="If inference_optimize is set to True, all optimizations except Triton are enabled.", + ) + parser.add_argument( + "--inference_optimize_triton", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.", + ) + parser.add_argument("--height", type=int, default=512, help="Height of the generated image.") + parser.add_argument("--width", type=int, default=512, help="Width of the generated image.") + parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.") + return parser.parse_args() + + +args = parse_args() + +if args.inference_optimize: + os.environ["INFERENCE_OPTIMIZE"] = "True" +if args.inference_optimize_triton: + os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" + + +image = pipe( + prompt, num_inference_steps=args.num_inference_steps, width=args.width, height=args.height, generator=generator +).images[0] + +if args.benchmark: + # warmup + for i in range(5): + image = pipe( + prompt, + num_inference_steps=args.num_inference_steps, + width=args.width, + height=args.height, + generator=generator, + ).images[0] + + repeat_times = 10 + paddle.device.synchronize() + starttime = datetime.datetime.now() + for i in range(repeat_times): + image = pipe( + prompt, + num_inference_steps=args.num_inference_steps, + width=args.width, + height=args.height, + generator=generator, + ).images[0] + paddle.device.synchronize() + endtime = datetime.datetime.now() + + duringtime = endtime - starttime + time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + print("The ave end to end time : ", time_ms / repeat_times, "ms") + + cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3) + print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") + image.save("text_to_image_generation-stable_diffusion_3-result.png") diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index 8b5a9d027..700ea5f37 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -14,8 +14,8 @@ from typing import Any, Dict, Optional import paddle -from paddle import nn import paddle.nn.functional as F +from paddle import nn from ..utils import USE_PEFT_BACKEND from ..utils.paddle_utils import maybe_allow_in_graph @@ -92,6 +92,7 @@ def forward(self, x: paddle.Tensor, objs: paddle.Tensor) -> paddle.Tensor: return x + @maybe_allow_in_graph class JointTransformerBlock(nn.Layer): r""" @@ -112,7 +113,6 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" self.norm1 = AdaLayerNormZero(dim) - if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" @@ -161,9 +161,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_size = chunk_size self._chunk_dim = dim - def forward( - self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor - ): 
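    # Dataflow of JointTransformerBlock.forward, summarized from the code below:
    #   1. norm1 / norm1_context apply adaLN(-Zero) modulation to both token streams using temb.
    #   2. self.attn runs joint attention over the concatenated image+text sequence.
    #   3. Gated residual adds (gate_msa / c_gate_msa) feed into adaLN-modulated feed-forwards.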
+ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor): norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: @@ -175,7 +173,8 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, ) # Process attention outputs for the `hidden_states`. @@ -184,6 +183,11 @@ def forward( norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + # import paddlemix + # # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp, epsilon=1e-6) + # hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + # hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 + # ) if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) @@ -202,6 +206,11 @@ def forward( norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + # import paddlemix + # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale_mlp, shift_mlp, epsilon=1e-6) + # encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + # encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 + # ) if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory context_ff_output = _chunked_feed_forward( @@ -213,6 +222,7 @@ def forward( return encoder_hidden_states, hidden_states + @maybe_allow_in_graph class BasicTransformerBlock(nn.Layer): r""" diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index c93c55ae6..f2c742cfc 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -906,6 +906,7 @@ def __call__( return hidden_states + class JointAttnProcessor2_5: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -923,7 +924,6 @@ def __call__( **kwargs, ) -> paddle.Tensor: residual = hidden_states - input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape @@ -931,7 +931,9 @@ def __call__( context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose([0, 2, 1]) + encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose( + [0, 2, 1] + ) batch_size = encoder_hidden_states.shape[0] @@ -945,6 +947,10 @@ def __call__( encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + # print("hidden_states_q", encoder_hidden_states_query_proj) + # print("hidden_states_K", encoder_hidden_states_key_proj) + # print("hidden_states_V", 
encoder_hidden_states_value_proj) + # attention query = paddle.concat([query, encoder_hidden_states_query_proj], axis=1) key = paddle.concat([key, encoder_hidden_states_key_proj], axis=1) @@ -962,6 +968,12 @@ def __call__( hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) hidden_states = hidden_states.astype(query.dtype) + # print("hidden_states",hidden_states) + # print("encoder_hidden_states",encoder_hidden_states) + # hidden_states.fill_(0.11189012) + # print("hidden_states", hidden_states) + # print("encoder_hidden_states",norm_encoder_hidden_states) + # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], @@ -970,7 +982,10 @@ def __call__( # linear proj hidden_states = attn.to_out[0](hidden_states) + # print(type(attn.to_out[0])) + # print("hidden_states", hidden_states) # dropout + hidden_states = attn.to_out[1](hidden_states) if not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) @@ -978,7 +993,9 @@ def __call__( if input_ndim == 4: hidden_states = hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) + encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape( + [batch_size, channel, height, width] + ) return hidden_states, encoder_hidden_states @@ -1009,7 +1026,9 @@ def __call__( context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose([0, 2, 1]) + encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose( + [0, 2, 1] + ) batch_size = encoder_hidden_states.shape[0] @@ -1060,10 +1079,13 @@ def __call__( if input_ndim == 4: hidden_states = hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) + encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape( + [batch_size, channel, height, width] + ) return hidden_states, encoder_hidden_states + class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 247ec0ebe..46a543469 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
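For reference, the joint attention implemented by JointAttnProcessor2_5 and its sibling processor above concatenates the image and text token streams, runs one scaled-dot-product attention over the combined sequence, and splits the result back at the image-sequence boundary. A minimal sketch of that flow (illustrative SD3-medium shapes at 512x512; it assumes a GPU Paddle build where paddle.nn.functional.scaled_dot_product_attention is available, whereas the patch itself goes through ppdiffusers' patched F.scaled_dot_product_attention_ wrapper):

    import paddle
    import paddle.nn.functional as F

    batch, img_len, txt_len, heads, head_dim = 2, 1024, 154, 24, 64

    # Per-stream Q/K/V projections would run first; random tensors stand in for them here.
    q = paddle.randn([batch, img_len + txt_len, heads, head_dim]).astype("float16")
    k = paddle.randn([batch, img_len + txt_len, heads, head_dim]).astype("float16")
    v = paddle.randn([batch, img_len + txt_len, heads, head_dim]).astype("float16")

    # One attention pass over the joint sequence ([batch, seq, heads, head_dim] layout).
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.reshape([batch, -1, heads * head_dim])

    # Split back into the image ("hidden_states") and text ("encoder_hidden_states") streams.
    hidden_states, encoder_hidden_states = paddle.split(out, [img_len, txt_len], axis=1)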
-import numbers from typing import Dict, Optional, Tuple import paddle @@ -62,8 +61,8 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) else: + # print("Using None") this self.emb = None - self.silu = nn.Silu() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim) norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) @@ -81,8 +80,13 @@ def forward( if self.emb is not None: emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) + # import paddlemix + # print(self.norm.weight,self.norm.bias) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + # x = paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, self.norm.weight,self.norm.bias,epsilon=1e-06) + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp @@ -161,6 +165,7 @@ def forward(self, x: paddle.Tensor, emb: paddle.Tensor) -> paddle.Tensor: x = x * (1 + scale) + shift return x + class AdaLayerNormContinuous(nn.Layer): def __init__( self, @@ -188,13 +193,16 @@ def __init__( def forward(self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor) -> paddle.Tensor: # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) - emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype)) scale, shift = paddle.chunk(emb, 2, axis=1) x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + # import paddlemix + # x = paddlemix.triton_ops.adaptive_layer_norm(x, scale, shift, self.norm.weight, self.norm.bias) return x + class RMSNorm(nn.Layer): - def __init__(self, dim, epsilon: float, elementwise_affine: bool = True): + def __init__(self, dim, epsilon: float, elementwise_affine: bool = True): super().__init__() self.epsilon = epsilon self.dim = dim @@ -214,4 +222,4 @@ def forward(self, hidden_states): norm_bias=None, epsilon=self.epsilon, begin_norm_axis=2, - ) \ No newline at end of file + ) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py new file mode 100644 index 000000000..56d238563 --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
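The new SimplifiedSD3 module below hard-codes SD3-medium's geometry (hidden size 1536 = 24 heads x 64, per-layer LayerLists for all 24 blocks) and re-implements JointTransformerBlock inline, so the whole stack can later be handed to the paddle.incubate.jit.inference / Triton path as one layer. Its recurring primitive is adaptive LayerNorm ("adaLN"): a SiLU + Linear on the timestep embedding produces shift/scale/gate vectors that modulate a parameter-free LayerNorm. A minimal reference of that modulation, equivalent to what both the eager branches and the paddlemix.triton_ops.adaptive_layer_norm kernel are expected to compute:

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    dim = 1536
    norm = nn.LayerNorm(dim, epsilon=1e-6, weight_attr=False, bias_attr=False)
    linear = nn.Linear(dim, 6 * dim)

    x = paddle.randn([2, 1024, dim])  # token stream
    temb = paddle.randn([2, dim])     # combined timestep/text-pooled embedding

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = linear(F.silu(temb)).chunk(6, axis=1)

    # adaLN: parameter-free LayerNorm, then per-sample scale/shift broadcast over tokens.
    normed = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]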
+ +# import math +# import os + +import paddle +import paddle.nn.functional as F +from paddle import nn + + +class SimplifiedSD3(nn.Layer): + def __init__( + self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int, context_pre_only=False + ): + super().__init__() + + self.context_pre_only = context_pre_only + self.num_layers = num_layers + self.dim = dim + self.bias = True + self.elementwise_affine = True + + # layer List + + # silu + matmul + add + # self.silu1 = nn.LayerList([nn.Silu() for i in range(num_layers)]) + self.silu1 = nn.Silu() + self.linear1 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers)]) # 1536,9216 + norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) + self.norm1 = nn.LayerList( + [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] + ) + + # not last layer + self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)]) + self.linear_context01 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers - 1)]) # 1536,9216 + self.norm1_context01 = nn.LayerList( + [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] + ) # another + + # last layer + self.silu2_context0 = nn.Silu() + self.linear_context0 = nn.Linear(1536, 1536 * 2, bias_attr=self.bias) + self.norm1_context0 = nn.LayerNorm(1536, epsilon=1e-06, weight_attr=False, bias_attr=self.bias) + + # attention + self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.ek = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + self.to_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) + # self.to_out = nn.LayerList([nn.Dropout(0.0) for i in range(num_layers)]) + + # not last layer + self.to_add_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers - 1)]) + + self.ffn_norm = nn.LayerList( + [nn.LayerNorm(1536, weight_attr=False, bias_attr=False, epsilon=1e-6) for i in range(num_layers)] + ) + self.ffn1 = nn.LayerList([nn.Linear(1536, 1536 * 4) for i in range(num_layers)]) + self.ffn2 = nn.LayerList([nn.Linear(1536 * 4, 1536) for i in range(num_layers)]) + + # not last layer + self.ffn_context_norm = nn.LayerList( + [nn.LayerNorm(1536, epsilon=1e-6, weight_attr=False, bias_attr=False) for i in range(num_layers - 1)] + ) + self.ffn_context1 = nn.LayerList([nn.Linear(1536, 1536 * 4) for i in range(num_layers - 1)]) + self.ffn_context2 = nn.LayerList([nn.Linear(1536 * 4, 1536) for i in range(num_layers - 1)]) + + def forward(self, hidden_states, encoder_hidden_states, temb): + + # nnemb = nn.Silu(temb) + temb_silu1 = self.silu1(temb) + temb_silu2 = self.silu1(temb) + + for i in range(self.num_layers): + context_pre_only = i == self.num_layers - 1 + + # emb = self.linear1[i](self.silu1(temb)) + emb = self.linear1[i](temb_silu1) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) + norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + # import paddlemix + # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa, epsilon=1e-06) + + if not 
context_pre_only: + # emb = self.linear_context01[i](self.silu2_context01[i](temb)) + emb = self.linear_context01[i](temb_silu2) + shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1) + norm_encoder_hidden_states = ( + self.norm1_context01[i](encoder_hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) + # import paddlemix + # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06) + + else: # last layer + emb = self.linear_context0(self.silu2_context0(temb).cast(encoder_hidden_states.dtype)) + scale, shift = paddle.chunk(emb, 2, axis=1) + norm_encoder_hidden_states = ( + self.norm1_context0(encoder_hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + ) + # print("self.norm1_context0.bias=",self.norm1_context0.bias) + # import paddlemix + # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale, shift, bias=self.norm1_context0.bias) + + # -------------------------^ attention ^----------------------- + residual = norm_hidden_states + q = self.q[i](norm_hidden_states) + k = self.k[i](norm_hidden_states) + v = self.v[i](norm_hidden_states) + eq = self.eq[i](norm_encoder_hidden_states) + ek = self.ek[i](norm_encoder_hidden_states) + ev = self.ev[i](norm_encoder_hidden_states) + q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64]) + k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64]) + v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64]) + + norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) + norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, 1536]) + norm_hidden_states1 = norm_hidden_states1.astype(q.dtype) + + attn_output, context_attn_output = ( + norm_hidden_states1[:, : residual.shape[1]], + norm_hidden_states1[:, residual.shape[1] :], + ) + + attn_output = paddle.nn.functional.linear( + attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias + ) + + if not context_pre_only: + context_attn_output = self.to_add_out_linear[i](context_attn_output) + + # -------------------------^ attention ^----------------------- + # ===============FF_First + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.ffn_norm[i](hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + # import paddlemix + # hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + # hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 + # ) + + ff_output = self.ffn1[i](norm_hidden_states) + ff_output = F.gelu(ff_output, approximate=True) + ff_output = self.ffn2[i](ff_output) + + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + ff_output + + # ===========FF_Second + if not context_pre_only: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + norm_encoder_hidden_states = self.ffn_context_norm[i](encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + # import paddlemix + # encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + # encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 + # ) + + 
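                # The commented fused_adaLN_scale_residual call above is the single-kernel form of
                # the four eager lines before it (assuming the paddlemix Triton op matches them):
                #   resi = encoder_hidden_states + c_gate_msa.unsqueeze(1) * context_attn_output
                #   out  = LayerNorm(resi) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
                # It returns both the updated residual stream and its normalized view.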
context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states) + context_ff_output = F.gelu(context_ff_output, approximate=True) + context_ff_output = self.ffn_context2[i](context_ff_output) + + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + else: + encoder_hidden_states = None + return encoder_hidden_states, hidden_states diff --git a/ppdiffusers/ppdiffusers/patches/paddle_patch.py b/ppdiffusers/ppdiffusers/patches/paddle_patch.py index 8a73e4315..6845ad632 100644 --- a/ppdiffusers/ppdiffusers/patches/paddle_patch.py +++ b/ppdiffusers/ppdiffusers/patches/paddle_patch.py @@ -463,7 +463,7 @@ def scaled_dot_product_attention_( pre_cache_length=0, ).transpose([0, 2, 1, 3]) elif attention_op == "flash": - with requires_grad_and_without_random(query, key, value): + with requires_grad_and_without_random(query, key, value, stop_gradient=False): output = paddle.nn.functional.scaled_dot_product_attention( query, key, From b0ea9efed356254a11f1103f6a3cc3e7b35a86eb Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 19 Aug 2024 05:43:28 +0000 Subject: [PATCH 02/65] optimize SD3 transformer_SD3 --- .../ppdiffusers/models/transformer_sd3.py | 232 +++++++++++++++--- 1 file changed, 201 insertions(+), 31 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 58f561d81..84d316e18 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -17,23 +17,31 @@ import paddle import paddle.nn as nn -from paddle.distributed.fleet.utils import recompute from ..configuration_utils import ConfigMixin, register_to_config + # from ..loaders import FromOriginalModelMixin, PeftAdapterMixin from ..models.attention import JointTransformerBlock from ..models.attention_processor import Attention, AttentionProcessor from ..models.modeling_utils import ModelMixin from ..models.normalization import AdaLayerNormContinuous -from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers, recompute_use_reentrant, use_old_recompute +from ..utils import ( # recompute_use_reentrant,; use_old_recompute, + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from .simplified_sd3 import SimplifiedSD3 from .transformer_2d import Transformer2DModelOutput +# from paddle.distributed.fleet.utils import recompute + logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class SD3Transformer2DModel(ModelMixin, ConfigMixin): # , PeftAdapterMixin, FromOriginalModelMixin +class SD3Transformer2DModel(ModelMixin, ConfigMixin): # , PeftAdapterMixin, FromOriginalModelMixin """ The Transformer model introduced in Stable Diffusion 3. 
Reference: https://arxiv.org/abs/2403.03206 @@ -100,6 +108,21 @@ def __init__( ] ) + self.simplified_sd3 = SimplifiedSD3( + num_layers, + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + # context_pre_onl, + ) + # self.simplified_sd3 = paddle.incubate.jit.inference( + # self.simplified_sd3, + # enable_new_ir=True, + # cache_static_model=False, + # exp_enable_use_cutlass=True, + # delete_pass_lists=["add_norm_fuse_pass"], + # ) + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias_attr=True) @@ -226,6 +249,27 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + # @paddle.incubate.jit.inference( + # enable_new_ir=True, + # cache_static_model=False, + # switch_ir_optim=True, + # exp_enable_use_cutlass=True, + # delete_pass_lists=["add_norm_fuse_pass"], + # ) + # def sd3_transformer( + # self, + # hidden_states, + # encoder_hidden_states, + # temb,): + # # print(encoder_hidden_states) + # for block in self.transformer_blocks: + # encoder_hidden_states, hidden_states = block( + # hidden_states=hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # temb=temb + # ) + # return encoder_hidden_states, hidden_states + def forward( self, hidden_states: paddle.Tensor, @@ -257,6 +301,11 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + + # print(str(self)) + # with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/AibinSD3/state_dict_817.txt", "a") as time_file: + # time_file.write(str(self.state_dict().keys())) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -267,9 +316,7 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - logger.info( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) + logger.info("Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.") height, width = hidden_states.shape[-2:] @@ -277,31 +324,41 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing and not use_old_recompute(): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - hidden_states = recompute( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) + # hidden_states_my = paddle.clone(hidden_states) + # encoder_hidden_states_my = paddle.clone(encoder_hidden_states) - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) + # for block in self.transformer_blocks: + # if self.training and self.gradient_checkpointing and not use_old_recompute() and False: + # pass + # else: + # encoder_hidden_states_my, hidden_states_my = block( + # hidden_states=hidden_states_my, encoder_hidden_states=encoder_hidden_states_my, temb=temb + # ) + + # encoder_hidden_states_my, hidden_states_my = self.simplified_sd3( + # hidden_states=hidden_states_my, encoder_hidden_states=encoder_hidden_states_my, temb=temb + # ) + + # for name, param in self.simplified_sd3.named_parameters(): + # if param.requires_grad: + # print(f"Layer: {name} | Shape: {param.shape} | Values: {param.data.numpy()[:5]}...") + # paddle.device.synchronize() + + encoder_hidden_states, hidden_states = self.simplified_sd3( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + # encoder_hidden_states = None + + # print((hidden_states - hidden_states_my)) + # print(paddle.max(paddle.abs(hidden_states - hidden_states_my))) + # exit() + + # hidden_states = self.sd3_transformer(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb) + # print(encoder_hidden_states) + # encoder_hidden_states = None + # print((hidden_states - hidden_states_my)) + # print(paddle.max(paddle.abs(hidden_states - hidden_states_my))) + # exit() hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) @@ -326,4 +383,117 @@ def custom_forward(*inputs): if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) \ No newline at end of file + return Transformer2DModelOutput(sample=output) + + @classmethod + def custom_modify_weight(cls, state_dict): + for i in range(24): + base_map_sd3 = [ + (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), + (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), + (f"q.{i}.weight", f"{i}.attn.to_q.weight"), + (f"q.{i}.bias", f"{i}.attn.to_q.bias"), + (f"k.{i}.weight", f"{i}.attn.to_k.weight"), + (f"k.{i}.bias", f"{i}.attn.to_k.bias"), + (f"v.{i}.weight", f"{i}.attn.to_v.weight"), + (f"v.{i}.bias", f"{i}.attn.to_v.bias"), + (f"ek.{i}.weight", f"{i}.attn.add_k_proj.weight"), + (f"ek.{i}.bias", f"{i}.attn.add_k_proj.bias"), + (f"ev.{i}.weight", f"{i}.attn.add_v_proj.weight"), + (f"ev.{i}.bias", f"{i}.attn.add_v_proj.bias"), + (f"eq.{i}.weight", f"{i}.attn.add_q_proj.weight"), + (f"eq.{i}.bias", 
f"{i}.attn.add_q_proj.bias"), + (f"to_out_linear.{i}.weight", f"{i}.attn.to_out.0.weight"), + (f"to_out_linear.{i}.bias", f"{i}.attn.to_out.0.bias"), + (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), + (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), + (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), + (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), + ] + if i < 23: + extra_map_sd3 = [ + (f"linear_context01.{i}.weight", f"{i}.norm1_context.linear.weight"), + (f"linear_context01.{i}.bias", f"{i}.norm1_context.linear.bias"), + (f"to_add_out_linear.{i}.weight", f"{i}.attn.to_add_out.weight"), + (f"to_add_out_linear.{i}.bias", f"{i}.attn.to_add_out.bias"), + (f"ffn_context1.{i}.weight", f"{i}.ff_context.net.0.proj.weight"), + (f"ffn_context1.{i}.bias", f"{i}.ff_context.net.0.proj.bias"), + (f"ffn_context2.{i}.weight", f"{i}.ff_context.net.2.weight"), + (f"ffn_context2.{i}.bias", f"{i}.ff_context.net.2.bias"), + ] + else: + extra_map_sd3 = [ + ("linear_context0.weight", f"{i}.norm1_context.linear.weight"), + ("linear_context0.bias", f"{i}.norm1_context.linear.bias"), + # (f"norm1_context0.{i}.bias", f"{i}.norm1_context.norm.bias"), + ] + map_sd3 = base_map_sd3 + extra_map_sd3 + + for to_, from_ in map_sd3: + if "transformer_blocks." + from_ in state_dict: + state_dict["simplified_sd3." + to_] = paddle.assign(state_dict["transformer_blocks." + from_]) + else: + print(f"Warning!!: '{from_}' not found in state_dict") + + # for i in range(24): + # if i < 23: + # map_sd3=[ + # (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), + # (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), + # (f"linear_context01.{i}.weight", f"{i}.norm1_context.linear.weight"), + # (f"linear_context01.{i}.bias", f"{i}.norm1_context.linear.bias"), + # (f"q.{i}.weight", f"{i}.attn.to_q.weight"), + # (f"q.{i}.bias", f"{i}.attn.to_q.bias"), + # (f"k.{i}.weight", f"{i}.attn.to_k.weight"), + # (f"k.{i}.bias", f"{i}.attn.to_k.bias"), + # (f"v.{i}.weight", f"{i}.attn.to_v.weight"), + # (f"v.{i}.bias", f"{i}.attn.to_v.bias"), + # (f"ek.{i}.weight", f"{i}.attn.add_k_proj.weight"), + # (f"ek.{i}.bias", f"{i}.attn.add_k_proj.bias"), + # (f"ev.{i}.weight", f"{i}.attn.add_v_proj.weight"), + # (f"ev.{i}.bias", f"{i}.attn.add_v_proj.bias"), + # (f"eq.{i}.weight", f"{i}.attn.add_q_proj.weight"), + # (f"eq.{i}.bias", f"{i}.attn.add_q_proj.bias"), + # (f"to_out_linear.{i}.weight", f"{i}.attn.to_out.0.weight"), + # (f"to_out_linear.{i}.bias", f"{i}.attn.to_out.0.bias"), + # (f"to_add_out_linear.{i}.weight", f"{i}.attn.to_add_out.weight"), + # (f"to_add_out_linear.{i}.bias", f"{i}.attn.to_add_out.bias"), + # (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), + # (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), + # (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), + # (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), + # (f"ffn_context1.{i}.weight", f"{i}.ff_context.net.0.proj.weight"), + # (f"ffn_context1.{i}.bias", f"{i}.ff_context.net.0.proj.bias"), + # (f"ffn_context2.{i}.weight", f"{i}.ff_context.net.2.weight"), + # (f"ffn_context2.{i}.bias", f"{i}.ff_context.net.2.bias"), + # ] + # else: + # map_sd3=[ + # (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), + # (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), + # (f"linear_context0.weight", f"{i}.norm1_context.linear.weight"), + # (f"linear_context0.bias", f"{i}.norm1_context.linear.bias"), + # (f"q.{i}.weight", f"{i}.attn.to_q.weight"), + # (f"q.{i}.bias", f"{i}.attn.to_q.bias"), + # (f"k.{i}.weight", f"{i}.attn.to_k.weight"), + # (f"k.{i}.bias", f"{i}.attn.to_k.bias"), + # (f"v.{i}.weight", 
f"{i}.attn.to_v.weight"), + # (f"v.{i}.bias", f"{i}.attn.to_v.bias"), + # (f"ek.{i}.weight", f"{i}.attn.add_k_proj.weight"), + # (f"ek.{i}.bias", f"{i}.attn.add_k_proj.bias"), + # (f"ev.{i}.weight", f"{i}.attn.add_v_proj.weight"), + # (f"ev.{i}.bias", f"{i}.attn.add_v_proj.bias"), + # (f"eq.{i}.weight", f"{i}.attn.add_q_proj.weight"), + # (f"eq.{i}.bias", f"{i}.attn.add_q_proj.bias"), + # (f"to_out_linear.{i}.weight", f"{i}.attn.to_out.0.weight"), + # (f"to_out_linear.{i}.bias", f"{i}.attn.to_out.0.bias"), + # (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), + # (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), + # (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), + # (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), + # ] + # for to_, from_ in map_sd3: + # if "transformer_blocks." + from_ in state_dict: + # state_dict["simplified_sd3." + to_] = paddle.assign(state_dict["transformer_blocks." + from_]) + # else: + # print(f"Warning!!: '{from_}' not found in state_dict") From f06a61a711de53c1d8f0bcfb1941879cf18278c7 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 19 Aug 2024 05:45:20 +0000 Subject: [PATCH 03/65] optimize SD3 transformer_SD3 --- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 84d316e18..403cef819 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -17,6 +17,7 @@ import paddle import paddle.nn as nn +from paddle.distributed.fleet.utils import recompute from ..configuration_utils import ConfigMixin, register_to_config @@ -25,19 +26,18 @@ from ..models.attention_processor import Attention, AttentionProcessor from ..models.modeling_utils import ModelMixin from ..models.normalization import AdaLayerNormContinuous -from ..utils import ( # recompute_use_reentrant,; use_old_recompute, +from ..utils import ( USE_PEFT_BACKEND, logging, + recompute_use_reentrant, scale_lora_layers, unscale_lora_layers, + use_old_recompute, ) from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from .simplified_sd3 import SimplifiedSD3 from .transformer_2d import Transformer2DModelOutput -# from paddle.distributed.fleet.utils import recompute - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name From dcff90c68cc883913cda3612f275552104dc933b Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 20 Aug 2024 06:11:17 +0000 Subject: [PATCH 04/65] update SD3 --- .../ppdiffusers/models/simplified_sd3.py | 154 ++++++++++++------ .../ppdiffusers/models/transformer_sd3.py | 114 ++++--------- 2 files changed, 138 insertions(+), 130 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 56d238563..6b212eada 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -18,6 +18,9 @@ import paddle import paddle.nn.functional as F from paddle import nn +from paddle.incubate.nn.functional import fused_linear, fused_linear_activation + +optimize = True class SimplifiedSD3(nn.Layer): @@ -36,22 +39,24 @@ def __init__( # silu + matmul + add # self.silu1 = nn.LayerList([nn.Silu() for i in range(num_layers)]) - self.silu1 = nn.Silu() + self.silu = nn.Silu() self.linear1 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers)]) # 1536,9216 + # self.linear1 = nn.Linear(1536, 6 * 1536 * 24) + 
norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) self.norm1 = nn.LayerList( [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] ) # not last layer - self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)]) + # self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)]) self.linear_context01 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers - 1)]) # 1536,9216 self.norm1_context01 = nn.LayerList( [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] ) # another # last layer - self.silu2_context0 = nn.Silu() + # self.silu2_context0 = nn.Silu() self.linear_context0 = nn.Linear(1536, 1536 * 2, bias_attr=self.bias) self.norm1_context0 = nn.LayerNorm(1536, epsilon=1e-06, weight_attr=False, bias_attr=self.bias) @@ -84,42 +89,59 @@ def __init__( def forward(self, hidden_states, encoder_hidden_states, temb): # nnemb = nn.Silu(temb) - temb_silu1 = self.silu1(temb) - temb_silu2 = self.silu1(temb) - + temb_silu1 = self.silu(temb) + # temb_silu2 = self.silu(temb) + # emb1 = self.linear1(temb_silu1) for i in range(self.num_layers): + # emb=emb1[:,i*6*1536:(i+1)*1536*6] context_pre_only = i == self.num_layers - 1 # emb = self.linear1[i](self.silu1(temb)) emb = self.linear1[i](temb_silu1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] - # import paddlemix - # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa, epsilon=1e-06) + if optimize: + import paddlemix + + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( + hidden_states, scale_msa, shift_msa, epsilon=1e-06 + ) + else: + norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] if not context_pre_only: # emb = self.linear_context01[i](self.silu2_context01[i](temb)) - emb = self.linear_context01[i](temb_silu2) + emb = self.linear_context01[i](temb_silu1) shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1) - norm_encoder_hidden_states = ( - self.norm1_context01[i](encoder_hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] - ) - # import paddlemix - # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06) + + if optimize: + import paddlemix + + norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( + encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06 + ) + else: + norm_encoder_hidden_states = ( + self.norm1_context01[i](encoder_hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) else: # last layer - emb = self.linear_context0(self.silu2_context0(temb).cast(encoder_hidden_states.dtype)) + emb = self.linear_context0(temb_silu1.cast(encoder_hidden_states.dtype)) scale, shift = paddle.chunk(emb, 2, axis=1) - norm_encoder_hidden_states = ( - self.norm1_context0(encoder_hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - ) - # print("self.norm1_context0.bias=",self.norm1_context0.bias) - # import paddlemix - # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale, shift, bias=self.norm1_context0.bias) + + if optimize: + import paddlemix + + norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( + 
encoder_hidden_states, scale, shift, bias=self.norm1_context0.bias + ) + else: + norm_encoder_hidden_states = ( + self.norm1_context0(encoder_hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + ) # -------------------------^ attention ^----------------------- - residual = norm_hidden_states + # residual = norm_hidden_states q = self.q[i](norm_hidden_states) k = self.k[i](norm_hidden_states) v = self.v[i](norm_hidden_states) @@ -129,15 +151,19 @@ def forward(self, hidden_states, encoder_hidden_states, temb): q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64]) k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64]) v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64]) + # qkv = paddle.concat([q, eq, k, ek, v, ev], axis=1).reshape([2, -1, 24, 64]) + # q,k,v = paddle.split(qkv,axis=1, num_or_sections=3) norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, 1536]) norm_hidden_states1 = norm_hidden_states1.astype(q.dtype) - attn_output, context_attn_output = ( - norm_hidden_states1[:, : residual.shape[1]], - norm_hidden_states1[:, residual.shape[1] :], - ) + # attn_output, context_attn_output = ( + # norm_hidden_states1[:, : residual.shape[1]], + # norm_hidden_states1[:, residual.shape[1] :], + # ) + + attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[1024, 154], axis=1) attn_output = paddle.nn.functional.linear( attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias @@ -148,38 +174,62 @@ def forward(self, hidden_states, encoder_hidden_states, temb): # -------------------------^ attention ^----------------------- # ===============FF_First - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - norm_hidden_states = self.ffn_norm[i](hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - # import paddlemix - # hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - # hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 - # ) - - ff_output = self.ffn1[i](norm_hidden_states) - ff_output = F.gelu(ff_output, approximate=True) - ff_output = self.ffn2[i](ff_output) + + if optimize: + import paddlemix + + hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 + ) + else: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.ffn_norm[i](hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + # ff_output = self.ffn1[i](norm_hidden_states) + # ff_output = fused_linear(norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias) + # ff_output = F.gelu(ff_output, approximate=True) + ff_output = fused_linear_activation( + norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias, activation="gelu" + ) + # ff_output = self.ffn2[i](ff_output) + ff_output = fused_linear(ff_output, self.ffn2[i].weight, self.ffn2[i].bias) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output # ===========FF_Second if not context_pre_only: - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - norm_encoder_hidden_states = 
self.ffn_context_norm[i](encoder_hidden_states) - norm_encoder_hidden_states = ( - norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if optimize: + import paddlemix + + ( + encoder_hidden_states, + norm_encoder_hidden_states, + ) = paddlemix.triton_ops.fused_adaLN_scale_residual( + encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 + ) + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + norm_encoder_hidden_states = self.ffn_context_norm[i](encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + + # context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states) + # context_ff_output = F.gelu(context_ff_output, approximate=True) + # context_ff_output = self.ffn_context2[i](context_ff_output) + context_ff_output = fused_linear_activation( + norm_encoder_hidden_states, + self.ffn_context1[i].weight, + self.ffn_context1[i].bias, + activation="gelu", + ) + context_ff_output = fused_linear( + context_ff_output, self.ffn_context2[i].weight, self.ffn_context2[i].bias ) - # import paddlemix - # encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - # encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 - # ) - - context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states) - context_ff_output = F.gelu(context_ff_output, approximate=True) - context_ff_output = self.ffn_context2[i](context_ff_output) encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output else: diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 403cef819..097a7d8ea 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -17,7 +17,6 @@ import paddle import paddle.nn as nn -from paddle.distributed.fleet.utils import recompute from ..configuration_utils import ConfigMixin, register_to_config @@ -26,18 +25,19 @@ from ..models.attention_processor import Attention, AttentionProcessor from ..models.modeling_utils import ModelMixin from ..models.normalization import AdaLayerNormContinuous -from ..utils import ( +from ..utils import ( # recompute_use_reentrant,; use_old_recompute, USE_PEFT_BACKEND, logging, - recompute_use_reentrant, scale_lora_layers, unscale_lora_layers, - use_old_recompute, ) from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from .simplified_sd3 import SimplifiedSD3 from .transformer_2d import Transformer2DModelOutput +# from paddle.distributed.fleet.utils import recompute + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -115,13 +115,13 @@ def __init__( attention_head_dim=self.inner_dim, # context_pre_onl, ) - # self.simplified_sd3 = paddle.incubate.jit.inference( - # self.simplified_sd3, - # enable_new_ir=True, - # cache_static_model=False, - # exp_enable_use_cutlass=True, - # delete_pass_lists=["add_norm_fuse_pass"], - # ) + self.simplified_sd3 = paddle.incubate.jit.inference( + self.simplified_sd3, + enable_new_ir=True, + cache_static_model=False, + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], + ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = 
nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias_attr=True) @@ -168,7 +168,9 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors( + name: str, module: paddle.nn.Module, processors: Dict[str, AttentionProcessor] + ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) @@ -201,7 +203,7 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + def fn_recursive_attn_processor(name: str, module: paddle.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) @@ -344,10 +346,19 @@ def forward( # print(f"Layer: {name} | Shape: {param.shape} | Values: {param.data.numpy()[:5]}...") # paddle.device.synchronize() - encoder_hidden_states, hidden_states = self.simplified_sd3( + paddle.device.synchronize() + import nvtx + + transformer_nvtx = nvtx.start_range(message="simple_transformer", color="yellow") + + hidden_states = self.simplified_sd3( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) - # encoder_hidden_states = None + + paddle.device.synchronize() + nvtx.end_range(transformer_nvtx) + + encoder_hidden_states = None # print((hidden_states - hidden_states_my)) # print(paddle.max(paddle.abs(hidden_states - hidden_states_my))) @@ -387,6 +398,8 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): + # state_dict["simplified_sd3.linear1.weight"] = paddle.assign(state_dict["transformer_blocks.0.norm1.linear.weight"]) + # state_dict["simplified_sd3.linear1.bias"] = paddle.assign(state_dict["transformer_blocks.0.norm1.linear.bias"]) for i in range(24): base_map_sd3 = [ (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), @@ -425,7 +438,6 @@ def custom_modify_weight(cls, state_dict): extra_map_sd3 = [ ("linear_context0.weight", f"{i}.norm1_context.linear.weight"), ("linear_context0.bias", f"{i}.norm1_context.linear.bias"), - # (f"norm1_context0.{i}.bias", f"{i}.norm1_context.norm.bias"), ] map_sd3 = base_map_sd3 + extra_map_sd3 @@ -435,65 +447,11 @@ def custom_modify_weight(cls, state_dict): else: print(f"Warning!!: '{from_}' not found in state_dict") - # for i in range(24): - # if i < 23: - # map_sd3=[ - # (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), - # (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), - # (f"linear_context01.{i}.weight", f"{i}.norm1_context.linear.weight"), - # (f"linear_context01.{i}.bias", f"{i}.norm1_context.linear.bias"), - # (f"q.{i}.weight", f"{i}.attn.to_q.weight"), - # (f"q.{i}.bias", f"{i}.attn.to_q.bias"), - # (f"k.{i}.weight", f"{i}.attn.to_k.weight"), - # (f"k.{i}.bias", f"{i}.attn.to_k.bias"), - # (f"v.{i}.weight", f"{i}.attn.to_v.weight"), - # (f"v.{i}.bias", f"{i}.attn.to_v.bias"), - # (f"ek.{i}.weight", f"{i}.attn.add_k_proj.weight"), - # (f"ek.{i}.bias", f"{i}.attn.add_k_proj.bias"), - # (f"ev.{i}.weight", f"{i}.attn.add_v_proj.weight"), - # (f"ev.{i}.bias", f"{i}.attn.add_v_proj.bias"), - # (f"eq.{i}.weight", f"{i}.attn.add_q_proj.weight"), - # (f"eq.{i}.bias", f"{i}.attn.add_q_proj.bias"), - # (f"to_out_linear.{i}.weight", f"{i}.attn.to_out.0.weight"), 
- # (f"to_out_linear.{i}.bias", f"{i}.attn.to_out.0.bias"), - # (f"to_add_out_linear.{i}.weight", f"{i}.attn.to_add_out.weight"), - # (f"to_add_out_linear.{i}.bias", f"{i}.attn.to_add_out.bias"), - # (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), - # (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), - # (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), - # (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), - # (f"ffn_context1.{i}.weight", f"{i}.ff_context.net.0.proj.weight"), - # (f"ffn_context1.{i}.bias", f"{i}.ff_context.net.0.proj.bias"), - # (f"ffn_context2.{i}.weight", f"{i}.ff_context.net.2.weight"), - # (f"ffn_context2.{i}.bias", f"{i}.ff_context.net.2.bias"), - # ] - # else: - # map_sd3=[ - # (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), - # (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), - # (f"linear_context0.weight", f"{i}.norm1_context.linear.weight"), - # (f"linear_context0.bias", f"{i}.norm1_context.linear.bias"), - # (f"q.{i}.weight", f"{i}.attn.to_q.weight"), - # (f"q.{i}.bias", f"{i}.attn.to_q.bias"), - # (f"k.{i}.weight", f"{i}.attn.to_k.weight"), - # (f"k.{i}.bias", f"{i}.attn.to_k.bias"), - # (f"v.{i}.weight", f"{i}.attn.to_v.weight"), - # (f"v.{i}.bias", f"{i}.attn.to_v.bias"), - # (f"ek.{i}.weight", f"{i}.attn.add_k_proj.weight"), - # (f"ek.{i}.bias", f"{i}.attn.add_k_proj.bias"), - # (f"ev.{i}.weight", f"{i}.attn.add_v_proj.weight"), - # (f"ev.{i}.bias", f"{i}.attn.add_v_proj.bias"), - # (f"eq.{i}.weight", f"{i}.attn.add_q_proj.weight"), - # (f"eq.{i}.bias", f"{i}.attn.add_q_proj.bias"), - # (f"to_out_linear.{i}.weight", f"{i}.attn.to_out.0.weight"), - # (f"to_out_linear.{i}.bias", f"{i}.attn.to_out.0.bias"), - # (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), - # (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), - # (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), - # (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), - # ] - # for to_, from_ in map_sd3: - # if "transformer_blocks." + from_ in state_dict: - # state_dict["simplified_sd3." + to_] = paddle.assign(state_dict["transformer_blocks." 
+ from_]) - # else: - # print(f"Warning!!: '{from_}' not found in state_dict") + # if i > 0: + # state_dict["simplified_sd3.linear1.weight"] = paddle.concat([state_dict["simplified_sd3.linear1.weight"], state_dict[f"transformer_blocks.{i}.norm1.linear.weight"]], axis=1) + # state_dict["simplified_sd3.linear1.bias"] = paddle.concat([state_dict["simplified_sd3.linear1.bias"],state_dict[f"transformer_blocks.{i}.norm1.linear.bias"]], axis=0) + # print("old_weight",state_dict[f"simplified_sd3.linear1.{i}.weight"]) + # print("old_bias",state_dict[f"simplified_sd3.linear1.{i}.bias"]) + + # print("weight",state_dict["simplified_sd3.linear1.weight"]) + # exit(0) From 15c5e44e3ebb14c3b3bda2d7d82b63a52d88e86d Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 20 Aug 2024 06:22:54 +0000 Subject: [PATCH 05/65] uodate triton &sim_SD3 --- paddlemix/triton_ops/triton_ops.py | 8 ++++---- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index 3ade229c3..5caf47131 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -865,9 +865,9 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil shift_mlp, resi_out, adaLN_out, - M, + -1, N, - seq_size, + -1, epsilon, N_npo2=N_npo2, weight_attr=weight_attr, @@ -1096,9 +1096,9 @@ def modulate(x, shift, scale): y, y, y, - M, + -1, N, - seq_size, + -1, epsilon, BLOCK_SIZE=BLOCK_SIZE, weight_attr=weight_attr, diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 6b212eada..1c1548753 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -88,9 +88,7 @@ def __init__( def forward(self, hidden_states, encoder_hidden_states, temb): - # nnemb = nn.Silu(temb) temb_silu1 = self.silu(temb) - # temb_silu2 = self.silu(temb) # emb1 = self.linear1(temb_silu1) for i in range(self.num_layers): # emb=emb1[:,i*6*1536:(i+1)*1536*6] From ab73a632b7ea5ff6a5b4b310f242aa3fe4c4a020 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 20 Aug 2024 07:16:39 +0000 Subject: [PATCH 06/65] modify temb_silu && modify nvtx --- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 10 +++++----- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 1c1548753..7b7362bbb 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -88,14 +88,14 @@ def __init__( def forward(self, hidden_states, encoder_hidden_states, temb): - temb_silu1 = self.silu(temb) - # emb1 = self.linear1(temb_silu1) + temb_silu = self.silu(temb) + # emb1 = self.linear1(temb_silu) for i in range(self.num_layers): # emb=emb1[:,i*6*1536:(i+1)*1536*6] context_pre_only = i == self.num_layers - 1 # emb = self.linear1[i](self.silu1(temb)) - emb = self.linear1[i](temb_silu1) + emb = self.linear1[i](temb_silu) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) if optimize: @@ -109,7 +109,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb): if not context_pre_only: # emb = self.linear_context01[i](self.silu2_context01[i](temb)) - emb = self.linear_context01[i](temb_silu1) + emb = self.linear_context01[i](temb_silu) shift_msa, 
                 if optimize:
@@ -124,7 +124,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                 )
             else:
                 # last layer
-                emb = self.linear_context0(temb_silu1.cast(encoder_hidden_states.dtype))
+                emb = self.linear_context0(temb_silu.cast(encoder_hidden_states.dtype))
                 scale, shift = paddle.chunk(emb, 2, axis=1)

                 if optimize:
diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 097a7d8ea..ddcab0034 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -346,17 +346,17 @@ def forward(
         #         print(f"Layer: {name} | Shape: {param.shape} | Values: {param.data.numpy()[:5]}...")
         #     paddle.device.synchronize()

-        paddle.device.synchronize()
-        import nvtx
+        # paddle.device.synchronize()
+        # import nvtx

-        transformer_nvtx = nvtx.start_range(message="simple_transformer", color="yellow")
+        # transformer_nvtx = nvtx.start_range(message="simple_transformer", color="yellow")

         hidden_states = self.simplified_sd3(
             hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
         )

-        paddle.device.synchronize()
-        nvtx.end_range(transformer_nvtx)
+        # paddle.device.synchronize()
+        # nvtx.end_range(transformer_nvtx)

         encoder_hidden_states = None

From ed2b7b173149f5b098beccfe7651ec25c06ed4c9 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Tue, 20 Aug 2024 07:52:17 +0000
Subject: [PATCH 07/65] modify linear from fused_linear

---
 .../ppdiffusers/models/simplified_sd3.py | 28 +++++--------------
 1 file changed, 7 insertions(+), 21 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 7b7362bbb..24f62eb8e 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -160,7 +160,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             #     norm_hidden_states1[:, : residual.shape[1]],
             #     norm_hidden_states1[:, residual.shape[1] :],
             # )
-
             attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[1024, 154], axis=1)

             attn_output = paddle.nn.functional.linear(
@@ -185,14 +184,10 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states = self.ffn_norm[i](hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

-            # ff_output = self.ffn1[i](norm_hidden_states)
-            # ff_output = fused_linear(norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias)
-            # ff_output = F.gelu(ff_output, approximate=True)
-            ff_output = fused_linear_activation(
-                norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias, activation="gelu"
-            )
-            # ff_output = self.ffn2[i](ff_output)
-            ff_output = fused_linear(ff_output, self.ffn2[i].weight, self.ffn2[i].bias)
+            ff_output = self.ffn1[i](norm_hidden_states)
+            ff_output = fused_linear(norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias)
+            ff_output = F.gelu(ff_output, approximate=True)
+            ff_output = self.ffn2[i](ff_output)

             ff_output = gate_mlp.unsqueeze(1) * ff_output
             hidden_states = hidden_states + ff_output
@@ -216,18 +211,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                     norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
                 )

-                # context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states)
-                # context_ff_output = F.gelu(context_ff_output, approximate=True)
-                # context_ff_output = self.ffn_context2[i](context_ff_output)
-                context_ff_output = fused_linear_activation(
-                    norm_encoder_hidden_states,
-                    self.ffn_context1[i].weight,
-                    self.ffn_context1[i].bias,
-                    activation="gelu",
-                )
-                context_ff_output = fused_linear(
-                    context_ff_output, self.ffn_context2[i].weight, self.ffn_context2[i].bias
-                )
+                context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states)
+                context_ff_output = F.gelu(context_ff_output, approximate=True)
+                context_ff_output = self.ffn_context2[i](context_ff_output)

                 encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
             else:

From f4330d3fb7cdc23fc67b6b0bcd5086b76b4c802e Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Tue, 20 Aug 2024 08:18:23 +0000
Subject: [PATCH 08/65] modify simplified_sd3

---
 ppdiffusers/ppdiffusers/models/simplified_sd3.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 24f62eb8e..65d2a879a 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -185,7 +185,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

             ff_output = self.ffn1[i](norm_hidden_states)
-            ff_output = fused_linear(norm_hidden_states, self.ffn1[i].weight, self.ffn1[i].bias)
             ff_output = F.gelu(ff_output, approximate=True)
             ff_output = self.ffn2[i](ff_output)

From cc1af0f007fcbb80cee9f2c42aece450bdbee41e Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Tue, 20 Aug 2024 14:57:00 +0000
Subject: [PATCH 09/65] add split_concat triton kernel

---
 paddlemix/triton_ops/__init__.py                 |   2 +
 paddlemix/triton_ops/triton_ops.py               | 162 ++++++++++++++++++
 .../ppdiffusers/models/simplified_sd3.py         |  47 +++--
 .../ppdiffusers/models/transformer_sd3.py        |  52 +++++-
 4 files changed, 245 insertions(+), 18 deletions(-)

diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py
index 76db91ab2..b6ba570ac 100644
--- a/paddlemix/triton_ops/__init__.py
+++ b/paddlemix/triton_ops/__init__.py
@@ -19,6 +19,7 @@
     adaptive_layer_norm,
     fused_adaLN_scale_residual,
     fused_rotary_emb,
+    my_splcat,
     paddle_use_triton,
     rms_norm,
     weight_only_int8,
@@ -39,6 +40,7 @@
         "rms_norm",
         "get_dtype_str",
         "fused_rotary_emb",
+        "my_splcat",
     ]
 except:
     pass
diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 5caf47131..d31ee2c56 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1567,3 +1567,165 @@ def fused_rotary_emb(
         outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
     )
     return q_out, k_out, v_out
+
+
+########################### adaptive layer norm ###############################
+split_concat_template = (
+    """
+
+
+std::vector<paddle::Tensor> ${op_name}_func(
+  const paddle::Tensor &x,
+  const paddle::Tensor &y) {
+
+  int batch = x.dims()[0];
+
+  int seq_qkv = x.dims()[1];
+  int seq_eqkv = y.dims()[1];
+  int output_hidden = x.dims()[2] / 3;
+
+
+  auto qkv = get_tensor_ptr(x);
+  auto eqkv = get_tensor_ptr(y);
+
+
+  auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+  auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
+
+  auto out0 = get_tensor_ptr(out0_tensor);
+  auto out1 = get_tensor_ptr(out1_tensor);
+  auto out2 = get_tensor_ptr(out2_tensor);
+
+
+  auto run_stream = out0_tensor.stream();
+
+"""
+    + tune_and_invoke_part
+    + """
+    return {out0_tensor, out1_tensor, out2_tensor};
+}
+
+std::vector<std::vector<int64_t>> ${op_name}_InferShape(
+        const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
+
+  std::vector<int64_t> out_shape = {A_shape[0], A_shape[1]+B_shape[1], A_shape[2]/3};
+
+  return {out_shape, out_shape, out_shape};
+}
+
+std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
+    return {A_dtype, A_dtype, A_dtype};
+}
+
+PD_BUILD_OP(${op_name})
+    .Inputs({"x", "y"})
+    .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
+    .SetKernelFn(PD_KERNEL(${op_name}_func))
+    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
+"""
+)
+
+
+@paddle_use_triton(
+    custom_op_template=split_concat_template,
+    key=["1"],
+)
+def splcat_kernel(
+    out0,
+    out1,
+    out2,
+    qkv,
+    eqkv,
+    batch,
+    seq_qkv,
+    seq_eqkv,
+    output_hidden,  # 1536
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    # grid = (3, batch, (seq_x + seq_y))
+    out_id = tl.program_id(axis=0)
+    batch = tl.program_id(axis=1)
+    out_row = tl.program_id(axis=2)
+    # if out_id == 0 and out_row < seq_qkv:
+    #     read_ptr = out_row * (output_hidden * 3) + out_id * output_hidden + x + (batch * seq_qkv * output_hidden * 3)
+    # elif out_id == 0 and out_row < seq_eqkv:
+    #     read_ptr = (out_row - seq_qkv) * (output_hidden * 3) + out_id * output_hidden + y
+
+    if out_row < seq_qkv:
+        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
+    else:
+        read_ptr = (
+            out_id * output_hidden
+            + (out_row - seq_qkv) * 3 * output_hidden
+            + batch * seq_eqkv * output_hidden * 3
+            + eqkv
+        )
+
+    read_offsets = tl.arange(0, BLOCK_SIZE)
+
+    mask = read_offsets < output_hidden
+
+    read_data = tl.load(read_ptr + read_offsets, mask=mask)
+
+    real_output = out0
+    if out_id == 1:
+        real_output = out1
+    elif out_id == 2:
+        real_output = out2
+
+    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets
+
+    tl.store(write_ptr, read_data, mask=mask)
+
+
+def my_splcat(x, y):
+    assert len(x.shape) == 3
+    assert len(y.shape) == 3
+
+    assert x.shape[0] == y.shape[0]
+    assert x.shape[2] == y.shape[2]
+
+    batch = x.shape[0]
+    seq_qkv = x.shape[1]
+    hidd_x = x.shape[2]
+    seq_eqkv = y.shape[1]
+    ouput_hidden = hidd_x // 3
+
+    op_name = "my_splitconcat"
+
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
+        grid = ("3", "batch", "seq_qkv + seq_eqkv")
+
+        splcat_kernel[(op_name, grid)](out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048)
+
+    if in_dynamic_or_pir_mode():
+        print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            x,
+            y,
+        )
+        return outs[0], outs[1], outs[2]
+    else:
+        print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+            "y": y,
+        }
+        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
+        )
+        return out0, out1, out2
diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 65d2a879a..0c57954b0 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -61,12 +61,15 @@ def __init__(
         self.norm1_context0 = nn.LayerNorm(1536, epsilon=1e-06, weight_attr=False, bias_attr=self.bias)

         # attention
-        self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.qkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)])
+
-        self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        self.ek = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.ek = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        # self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.eqkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)])
         self.to_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
         # self.to_out = nn.LayerList([nn.Dropout(0.0) for i in range(num_layers)])

@@ -140,15 +143,29 @@ def forward(self, hidden_states, encoder_hidden_states, temb):

             # -------------------------^ attention ^-----------------------
             # residual = norm_hidden_states
-            q = self.q[i](norm_hidden_states)
-            k = self.k[i](norm_hidden_states)
-            v = self.v[i](norm_hidden_states)
-            eq = self.eq[i](norm_encoder_hidden_states)
-            ek = self.ek[i](norm_encoder_hidden_states)
-            ev = self.ev[i](norm_encoder_hidden_states)
-            q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64])
-            k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64])
-            v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64])
+            # q = self.q[i](norm_hidden_states)
+            # k = self.k[i](norm_hidden_states)
+            # v = self.v[i](norm_hidden_states)
+            qkv = self.qkv[i](norm_hidden_states)
+            # q,k,v = paddle.split(qkv,axis=2, num_or_sections=3)
+
+            # eq = self.eq[i](norm_encoder_hidden_states)
+            # ek = self.ek[i](norm_encoder_hidden_states)
+            # ev = self.ev[i](norm_encoder_hidden_states)
+            eqkv = self.eqkv[i](norm_encoder_hidden_states)
+            # eq,ek,ev = paddle.split(eqkv,axis=2, num_or_sections=3)
+
+            # q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64])
+            # k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64])
+            # v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64])
+
+            import paddlemix
+
+            q, k, v = paddlemix.triton_ops.my_splcat(qkv, eqkv)
+            q = q.reshape([2, -1, 24, 64])
+            k = k.reshape([2, -1, 24, 64])
+            v = v.reshape([2, -1, 24, 64])
+
             # qkv = paddle.concat([q, eq, k, ek, v, ev], axis=1).reshape([2, -1, 24, 64])
             # q,k,v = paddle.split(qkv,axis=1, num_or_sections=3)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index ddcab0034..9d478f8fe 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -447,11 +447,57 @@ def custom_modify_weight(cls, state_dict):
             else:
                 print(f"Warning!!: '{from_}' not found in state_dict")
-
             # if i > 0:
             # state_dict["simplified_sd3.linear1.weight"] = paddle.concat([state_dict["simplified_sd3.linear1.weight"], state_dict[f"transformer_blocks.{i}.norm1.linear.weight"]], axis=1)
             # state_dict["simplified_sd3.linear1.bias"] = paddle.concat([state_dict["simplified_sd3.linear1.bias"],state_dict[f"transformer_blocks.{i}.norm1.linear.bias"]], axis=0)
             # print("old_weight",state_dict[f"simplified_sd3.linear1.{i}.weight"])
             # print("old_bias",state_dict[f"simplified_sd3.linear1.{i}.bias"])
-
-        # print("weight",state_dict["simplified_sd3.linear1.weight"])
+            state_dict[f"simplified_sd3.qkv.{i}.weight"] = paddle.assign(
+                paddle.concat(
+                    [
+                        state_dict[f"simplified_sd3.q.{i}.weight"],
+                        state_dict[f"simplified_sd3.k.{i}.weight"],
+                        state_dict[f"simplified_sd3.v.{i}.weight"],
+                    ],
+                    axis=1,
+                )
+            )
+            state_dict[f"simplified_sd3.qkv.{i}.bias"] = paddle.assign(
+                paddle.concat(
+                    [
+                        state_dict[f"simplified_sd3.q.{i}.bias"],
+                        state_dict[f"simplified_sd3.k.{i}.bias"],
+                        state_dict[f"simplified_sd3.v.{i}.bias"],
+                    ],
+                    axis=0,
+                )
+            )
+            state_dict[f"simplified_sd3.eqkv.{i}.weight"] = paddle.assign(
+                paddle.concat(
+                    [
+                        state_dict[f"simplified_sd3.eq.{i}.weight"],
+                        state_dict[f"simplified_sd3.ek.{i}.weight"],
+                        state_dict[f"simplified_sd3.ev.{i}.weight"],
+                    ],
+                    axis=1,
+                )
+            )
+            state_dict[f"simplified_sd3.eqkv.{i}.bias"] = paddle.assign(
+                paddle.concat(
+                    [
+                        state_dict[f"simplified_sd3.eq.{i}.bias"],
+                        state_dict[f"simplified_sd3.ek.{i}.bias"],
+                        state_dict[f"simplified_sd3.ev.{i}.bias"],
+                    ],
+                    axis=0,
+                )
+            )
+            print("old_weight_q", state_dict[f"simplified_sd3.eq.{i}.bias"])
+            print("old_weight_k", state_dict[f"simplified_sd3.ek.{i}.bias"])
+            print("old_weight_v", state_dict[f"simplified_sd3.ev.{i}.bias"])
+            print(
+                "weight",
+                state_dict[f"simplified_sd3.eqkv.{i}.bias"],
+            )
+            # exit(0)
+            # print("weight",state_dict["simplified_sd3.linear1.weight"])
+            # exit(0)
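# NOTE: illustration, not part of the patches. On the concat axes used above:
# paddle.nn.Linear stores its weight as [in_features, out_features], so stacking the
# q/k/v weights along axis=1 (the output axis) and the biases along axis=0 turns three
# 1536->1536 projections into one 1536->4608 projection whose output is q, k, v laid
# side by side. A small self-check sketch (shapes are illustrative, not tied to the
# SD3 checkpoint):
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 154, 1536])
wq, wk, wv = (paddle.randn([1536, 1536]) for _ in range(3))
bq, bk, bv = (paddle.randn([1536]) for _ in range(3))
w_qkv = paddle.concat([wq, wk, wv], axis=1)  # [1536, 4608]
b_qkv = paddle.concat([bq, bk, bv], axis=0)  # [4608]
q, k, v = paddle.split(F.linear(x, w_qkv, b_qkv), num_or_sections=3, axis=2)
assert bool(paddle.allclose(q, F.linear(x, wq, bq), atol=1e-4))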
From 70e6b6e4b3e0871d3c5a69be7250184e95f0c0f2 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Wed, 21 Aug 2024 03:43:16 +0000
Subject: [PATCH 10/65] modify split_concat triton kernel

---
 paddlemix/triton_ops/__init__.py                 |  4 +--
 paddlemix/triton_ops/triton_ops.py               | 29 +++++++------------
 .../ppdiffusers/models/simplified_sd3.py         |  5 +---
 .../ppdiffusers/models/transformer_sd3.py        | 10 -------
 4 files changed, 14 insertions(+), 34 deletions(-)

diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py
index b6ba570ac..4c2f1691d 100644
--- a/paddlemix/triton_ops/__init__.py
+++ b/paddlemix/triton_ops/__init__.py
@@ -19,9 +19,9 @@
     adaptive_layer_norm,
     fused_adaLN_scale_residual,
     fused_rotary_emb,
-    my_splcat,
     paddle_use_triton,
     rms_norm,
+    split_concat,
     weight_only_int8,
 )
 from .triton_utils import (
@@ -40,7 +40,7 @@
         "rms_norm",
         "get_dtype_str",
         "fused_rotary_emb",
-        "my_splcat",
+        "split_concat",
     ]
 except:
     pass
diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index d31ee2c56..58c2b7ef0 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1569,11 +1569,9 @@ def fused_rotary_emb(
     return q_out, k_out, v_out


-########################### adaptive layer norm ###############################
+########################### split concat ###############################
 split_concat_template = (
     """
-
-
 std::vector<paddle::Tensor> ${op_name}_func(
   const paddle::Tensor &x,
   const paddle::Tensor &y) {
@@ -1632,7 +1630,7 @@ def fused_rotary_emb(
     custom_op_template=split_concat_template,
     key=["1"],
 )
-def splcat_kernel(
+def split_concat_kernel(
     out0,
     out1,
     out2,
@@ -1641,19 +1639,12 @@ def splcat_kernel(
     batch,
     seq_qkv,
     seq_eqkv,
-    output_hidden,  # 1536
+    output_hidden,
     BLOCK_SIZE: tl.constexpr,
 ):
-
-    # grid = (3, batch, (seq_x + seq_y))
     out_id = tl.program_id(axis=0)
     batch = tl.program_id(axis=1)
     out_row = tl.program_id(axis=2)
-    # if out_id == 0 and out_row < seq_qkv:
-    #     read_ptr = out_row * (output_hidden * 3) + out_id * output_hidden + x + (batch * seq_qkv * output_hidden * 3)
-    # elif out_id == 0 and out_row < seq_eqkv:
-    #     read_ptr = (out_row - seq_qkv) * (output_hidden * 3) + out_id * output_hidden + y
-
     if out_row < seq_qkv:
         read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
     else:
@@ -1665,9 +1656,7 @@ def split_concat_kernel(
         )

     read_offsets = tl.arange(0, BLOCK_SIZE)
-
     mask = read_offsets < output_hidden
-
     read_data = tl.load(read_ptr + read_offsets, mask=mask)

     real_output = out0
@@ -1681,7 +1670,7 @@ def split_concat_kernel(
     tl.store(write_ptr, read_data, mask=mask)


-def my_splcat(x, y):
+def split_concat(x, y):
     assert len(x.shape) == 3
     assert len(y.shape) == 3

@@ -1693,8 +1682,10 @@ def my_splcat(x, y):
     hidd_x = x.shape[2]
     seq_eqkv = y.shape[1]
     ouput_hidden = hidd_x // 3
-
-    op_name = "my_splitconcat"
+    BLOCK_SIZE = triton.next_power_of_2(ouput_hidden)
+    op_name = "triton_split_concat"
+    op_name += get_dtype_str(x.dtype)
+    op_name += f"_{BLOCK_SIZE}"

     if op_name not in OpProtoHolder.instance().op_proto_map.keys():
         out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
@@ -1702,7 +1693,9 @@ def my_splcat(x, y):
         out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype)
         grid = ("3", "batch", "seq_qkv + seq_eqkv")

-        splcat_kernel[(op_name, grid)](out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048)
+        split_concat_kernel[(op_name, grid)](
+            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048
+        )

     if in_dynamic_or_pir_mode():
         print(f"== we are in dynamic mode, op_name: {op_name}")
diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 0c57954b0..59bcc8024 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -71,7 +71,6 @@ def __init__(
         # self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
         self.eqkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)])
         self.to_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        # self.to_out = nn.LayerList([nn.Dropout(0.0) for i in range(num_layers)])

         # not last layer
         self.to_add_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers - 1)])
@@ -97,7 +96,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             # emb=emb1[:,i*6*1536:(i+1)*1536*6]
             context_pre_only = i == self.num_layers - 1

-            # emb = self.linear1[i](self.silu1(temb))
             emb = self.linear1[i](temb_silu)
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)

@@ -111,7 +109,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]

             if not context_pre_only:
-                # emb = self.linear_context01[i](self.silu2_context01[i](temb))
                 emb = self.linear_context01[i](temb_silu)
                 shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1)

@@ -161,7 +158,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):

             import paddlemix

-            q, k, v = paddlemix.triton_ops.my_splcat(qkv, eqkv)
+            q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
             q = q.reshape([2, -1, 24, 64])
             k = k.reshape([2, -1, 24, 64])
             v = v.reshape([2, -1, 24, 64])
diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 9d478f8fe..29cbd29af 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -491,13 +491,3 @@ def custom_modify_weight(cls, state_dict):
                     axis=0,
                 )
             )
-            print("old_weight_q", state_dict[f"simplified_sd3.eq.{i}.bias"])
-            print("old_weight_k", state_dict[f"simplified_sd3.ek.{i}.bias"])
-            print("old_weight_v", state_dict[f"simplified_sd3.ev.{i}.bias"])
-            print(
-                "weight",
-                state_dict[f"simplified_sd3.eqkv.{i}.bias"],
-            )
-            # exit(0)
-            # print("weight",state_dict["simplified_sd3.linear1.weight"])
-            # exit(0)

From 9543b118fbfa8a9fe328ce7d5a8f8a532aa713b6 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Wed, 21 Aug 2024 07:33:49 +0000
Subject: [PATCH 11/65] update

---
 ..._to_image_generation-stable_diffusion_3.py |  53 +++---
 ppdiffusers/ppdiffusers/models/attention.py   |  47 +++---
 .../ppdiffusers/models/normalization.py       |  23 ++-
 .../ppdiffusers/models/simplified_sd3.py      |  97 ++++------
 .../ppdiffusers/models/transformer_sd3.py     | 154 ++++++++----------
 5 files changed, 181 insertions(+), 193 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index 411cb9f25..b1e266223 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -21,15 +21,6 @@
 from ppdiffusers import StableDiffusion3Pipeline

-pipe = StableDiffusion3Pipeline.from_pretrained(
-    "/root/.cache/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/stable-diffusion-3-medium-diffusers",
-    paddle_dtype=paddle.float16,
-    from_hf_hub=True,
-    from_diffusers=True,
-)
-generator = paddle.Generator().manual_seed(42)
-prompt = "A cat holding a sign that says hello world"
-

 def parse_args():
     parser = argparse.ArgumentParser(
@@ -50,9 +41,15 @@ def parse_args():
     parser.add_argument(
         "--inference_optimize_triton",
         type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
-        default=False,
+        default=True,
         help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.",
     )
+    parser.add_argument(
+        "--inference_optimize_origin",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="If inference_optimize_origin is set to True, the original dynamic graph inference optimization is enabled.",
+    )
     parser.add_argument("--height", type=int, default=512, help="Height of the generated image.")
     parser.add_argument("--width", type=int, default=512, help="Width of the generated image.")
     parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.")
@@ -65,6 +62,18 @@ def parse_args():
     os.environ["INFERENCE_OPTIMIZE"] = "True"
 if args.inference_optimize_triton:
     os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+if args.inference_optimize_origin:
+    os.environ["INFERENCE_OPTIMIZE_ORIGIN"] = "True"
+
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3-medium-diffusers",
+    paddle_dtype=paddle.float16,
+    # from_hf_hub=True,
+    # from_diffusers=True,
+)
+generator = paddle.Generator().manual_seed(42)
+prompt = "A cat holding a sign that says hello world"


 image = pipe(
@@ -73,7 +82,7 @@ def parse_args():

 if args.benchmark:
     # warmup
-    for i in range(5):
+    for i in range(3):
         image = pipe(
             prompt,
             num_inference_steps=args.num_inference_steps,
@@ -82,10 +91,11 @@ def parse_args():
             generator=generator,
         ).images[0]

-    repeat_times = 10
-    paddle.device.synchronize()
-    starttime = datetime.datetime.now()
+    repeat_times = 5
+    sumtime = 0.0
     for i in range(repeat_times):
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
         image = pipe(
             prompt,
             num_inference_steps=args.num_inference_steps,
@@ -93,13 +103,14 @@ def parse_args():
             height=args.height,
             generator=generator,
         ).images[0]
-    paddle.device.synchronize()
-    endtime = datetime.datetime.now()
-
-    duringtime = endtime - starttime
-    time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
-    print("The ave end to end time : ", time_ms / repeat_times, "ms")
-
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
+        duringtime = endtime - starttime
+        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        sumtime += duringtime
+        print("This end to end time : ", duringtime, "ms")
+
+    print("The ave end to end time : ", sumtime / repeat_times, "ms")
     cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
     print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py
index 700ea5f37..88a1abaf0 100644
--- a/ppdiffusers/ppdiffusers/models/attention.py
+++ b/ppdiffusers/ppdiffusers/models/attention.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from typing import Any, Dict, Optional

 import paddle
@@ -178,16 +179,18 @@ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Te
         )

         # Process attention outputs for the `hidden_states`.
-        attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = hidden_states + attn_output
-
-        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-        # import paddlemix
-        # # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp, epsilon=1e-6)
-        # hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
-        #     hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06
-        # )
+        if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            import paddlemix
+
+            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06
+            )
+        else:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+            hidden_states = hidden_states + attn_output
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
@@ -201,16 +204,20 @@ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Te
         if self.context_pre_only:
             encoder_hidden_states = None
         else:
-            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
-            encoder_hidden_states = encoder_hidden_states + context_attn_output
-
-            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
-            # import paddlemix
-            # norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale_mlp, shift_mlp, epsilon=1e-6)
-            # encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
-            #     encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
-            # )
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+                import paddlemix
+
+                encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                    encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
+                )
+            else:
+                context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+                encoder_hidden_states = encoder_hidden_states + context_attn_output
+                norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+                norm_encoder_hidden_states = (
+                    norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+                )
+
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
             context_ff_output = _chunked_feed_forward(
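# NOTE: reference sketch, not part of the patches. Same pattern in normalization.py
# below: adaptive_layer_norm(x, scale, shift, weight, bias) is expected to match the
# eager expression it replaces (assuming a LayerNorm over the last axis):
#
#     norm = paddle.nn.functional.layer_norm(x, x.shape[-1:], weight, bias, epsilon)
#     out = norm * (1 + scale[:, None]) + shift[:, None]
#
# which is the standard adaLN modulation used throughout the DiT blocks.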
diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py
index 46a543469..10c0318b0 100644
--- a/ppdiffusers/ppdiffusers/models/normalization.py
+++ b/ppdiffusers/ppdiffusers/models/normalization.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from typing import Dict, Optional, Tuple

 import paddle
@@ -82,10 +83,15 @@ def forward(
         emb = self.linear(self.silu(emb))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
-        # import paddlemix
-        # print(self.norm.weight,self.norm.bias)
-        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
-        # x = paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, self.norm.weight,self.norm.bias,epsilon=1e-06)
+
+        if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            import paddlemix
+
+            x = paddlemix.triton_ops.adaptive_layer_norm(
+                x, scale_msa, shift_msa, self.norm.weight, self.norm.bias, epsilon=1e-06
+            )
+        else:
+            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

@@ -195,9 +201,12 @@ def forward(self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor) -> pa
         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
         emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype))
         scale, shift = paddle.chunk(emb, 2, axis=1)
-        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
-        # import paddlemix
-        # x = paddlemix.triton_ops.adaptive_layer_norm(x, scale, shift, self.norm.weight, self.norm.bias)
+        if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            import paddlemix
+
+            x = paddlemix.triton_ops.adaptive_layer_norm(x, scale, shift, self.norm.weight, self.norm.bias)
+        else:
+            x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 59bcc8024..ec7a02c45 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 # import math
-# import os
+import os

 import paddle
 import paddle.nn.functional as F
@@ -33,55 +33,38 @@ def __init__(
         self.num_layers = num_layers
         self.dim = dim
         self.bias = True
-        self.elementwise_affine = True
-
-        # layer List
+        norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)

-        # silu + matmul + add
         # self.silu1 = nn.LayerList([nn.Silu() for i in range(num_layers)])
         self.silu = nn.Silu()
         self.linear1 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers)])  # 1536,9216
         # self.linear1 = nn.Linear(1536, 6 * 1536 * 24)
-
-        norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
         self.norm1 = nn.LayerList(
             [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)]
         )
-
-        # not last layer
         # self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)])
         self.linear_context01 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers - 1)])  # 1536,9216
         self.norm1_context01 = nn.LayerList(
             [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)]
-        )  # another
-
-        # last layer
+        )
         # self.silu2_context0 = nn.Silu()
         self.linear_context0 = nn.Linear(1536, 1536 * 2, bias_attr=self.bias)
         self.norm1_context0 = nn.LayerNorm(1536, epsilon=1e-06, weight_attr=False, bias_attr=self.bias)
-
-        # attention
-        # self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        # self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        # self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
         self.qkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)])
-
-        # self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        # self.ek = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-        # self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.ek = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
+        self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
         self.eqkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)])
         self.to_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)])
-
-        # not last layer
         self.to_add_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers - 1)])
-
         self.ffn_norm = nn.LayerList(
             [nn.LayerNorm(1536, weight_attr=False, bias_attr=False, epsilon=1e-6) for i in range(num_layers)]
         )
         self.ffn1 = nn.LayerList([nn.Linear(1536, 1536 * 4) for i in range(num_layers)])
         self.ffn2 = nn.LayerList([nn.Linear(1536 * 4, 1536) for i in range(num_layers)])
-
-        # not last layer
         self.ffn_context_norm = nn.LayerList(
             [nn.LayerNorm(1536, epsilon=1e-6, weight_attr=False, bias_attr=False) for i in range(num_layers - 1)]
         )
@@ -99,7 +82,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             emb = self.linear1[i](temb_silu)
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)

-            if optimize:
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
                 import paddlemix

                 norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
@@ -111,7 +94,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]

             if not context_pre_only:
                 emb = self.linear_context01[i](temb_silu)
                 shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1)

-                if optimize:
+                if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
                     import paddlemix

                     norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
@@ -123,11 +106,11 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                         self.norm1_context01[i](encoder_hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
                     )

-            else:  # last layer
+            else:
                 emb = self.linear_context0(temb_silu.cast(encoder_hidden_states.dtype))
                 scale, shift = paddle.chunk(emb, 2, axis=1)

-                if optimize:
+                if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
                     import paddlemix

                     norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
@@ -138,42 +121,34 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                         self.norm1_context0(encoder_hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
                     )

-            # -------------------------^ attention ^-----------------------
-            # residual = norm_hidden_states
-            # q = self.q[i](norm_hidden_states)
-            # k = self.k[i](norm_hidden_states)
-            # v = self.v[i](norm_hidden_states)
-            qkv = self.qkv[i](norm_hidden_states)
-            # q,k,v = paddle.split(qkv,axis=2, num_or_sections=3)
-
-            # eq = self.eq[i](norm_encoder_hidden_states)
-            # ek = self.ek[i](norm_encoder_hidden_states)
-            # ev = self.ev[i](norm_encoder_hidden_states)
-            eqkv = self.eqkv[i](norm_encoder_hidden_states)
-            # eq,ek,ev = paddle.split(eqkv,axis=2, num_or_sections=3)
-
-            # q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64])
-            # k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64])
-            # v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64])
-
-            import paddlemix
-
-            q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
-            q = q.reshape([2, -1, 24, 64])
-            k = k.reshape([2, -1, 24, 64])
-            v = v.reshape([2, -1, 24, 64])
-
-            # qkv = paddle.concat([q, eq, k, ek, v, ev], axis=1).reshape([2, -1, 24, 64])
-            # q,k,v = paddle.split(qkv,axis=1, num_or_sections=3)
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+                qkv = self.qkv[i](norm_hidden_states)
+                eqkv = self.eqkv[i](norm_encoder_hidden_states)
+
+                import paddlemix
+
+                q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
+                q = q.reshape([2, -1, 24, 64])
+                k = k.reshape([2, -1, 24, 64])
+                v = v.reshape([2, -1, 24, 64])
+            else:
+                # residual = norm_hidden_states
+                q = self.q[i](norm_hidden_states)
+                k = self.k[i](norm_hidden_states)
+                v = self.v[i](norm_hidden_states)
+
+                eq = self.eq[i](norm_encoder_hidden_states)
+                ek = self.ek[i](norm_encoder_hidden_states)
+                ev = self.ev[i](norm_encoder_hidden_states)
+
+                q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64])
+                k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64])
+                v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64])

             norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
             norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, 1536])
             norm_hidden_states1 = norm_hidden_states1.astype(q.dtype)

-            # attn_output, context_attn_output = (
-            #     norm_hidden_states1[:, : residual.shape[1]],
-            #     norm_hidden_states1[:, residual.shape[1] :],
-            # )
             attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[1024, 154], axis=1)

             attn_output = paddle.nn.functional.linear(
@@ -186,10 +161,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             if not context_pre_only:
                 context_attn_output = self.to_add_out_linear[i](context_attn_output)

-            # -------------------------^ attention ^-----------------------
-            # ===============FF_First
-
-            if optimize:
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
                 import paddlemix

                 hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
@@ -205,9 +177,8 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             ff_output = gate_mlp.unsqueeze(1) * ff_output
             hidden_states = hidden_states + ff_output

-            # ===========FF_Second
             if not context_pre_only:
-                if optimize:
+                if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
                     import paddlemix

                     (
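# NOTE: illustration, not part of the patches. Shape walk-through for the joint
# attention in simplified_sd3.py above, using the public scaled_dot_product_attention
# (the code itself calls a ppdiffusers-patched F.scaled_dot_product_attention_). Values
# are the 512x512 defaults: batch 2 after classifier-free-guidance duplication,
# 1024 image tokens plus 154 text tokens, and 24 heads x 64 dims = 1536.
import paddle
import paddle.nn.functional as F

q = paddle.randn([2, 1024 + 154, 24, 64])
k = paddle.randn([2, 1024 + 154, 24, 64])
v = paddle.randn([2, 1024 + 154, 24, 64])
out = F.scaled_dot_product_attention(q, k, v)  # joint attention over image + text tokens
out = out.reshape([2, -1, 1536])
# split the joint output back into the image and text streams by sequence length
attn_output, context_attn_output = paddle.split(out, num_or_sections=[1024, 154], axis=1)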
hidden_states, + encoder_hidden_states, + temb, + ): + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + return encoder_hidden_states, hidden_states def forward( self, @@ -326,50 +328,44 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - # hidden_states_my = paddle.clone(hidden_states) - # encoder_hidden_states_my = paddle.clone(encoder_hidden_states) - - # for block in self.transformer_blocks: - # if self.training and self.gradient_checkpointing and not use_old_recompute() and False: - # pass - # else: - # encoder_hidden_states_my, hidden_states_my = block( - # hidden_states=hidden_states_my, encoder_hidden_states=encoder_hidden_states_my, temb=temb - # ) - - # encoder_hidden_states_my, hidden_states_my = self.simplified_sd3( - # hidden_states=hidden_states_my, encoder_hidden_states=encoder_hidden_states_my, temb=temb - # ) - - # for name, param in self.simplified_sd3.named_parameters(): - # if param.requires_grad: - # print(f"Layer: {name} | Shape: {param.shape} | Values: {param.data.numpy()[:5]}...") - # paddle.device.synchronize() - - # paddle.device.synchronize() - # import nvtx - - # transformer_nvtx = nvtx.start_range(message="simple_transformer", color="yellow") - - hidden_states = self.simplified_sd3( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) - - # paddle.device.synchronize() - # nvtx.end_range(transformer_nvtx) + if self.inference_optimize: + hidden_states = self.simplified_sd3( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + encoder_hidden_states = None - encoder_hidden_states = None + elif self.inference_optimize_origin: + hidden_states = self.sd3_origin_transformer( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + encoder_hidden_states = None - # print((hidden_states - hidden_states_my)) - # print(paddle.max(paddle.abs(hidden_states - hidden_states_my))) - # exit() + else: + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing and not use_old_recompute(): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + hidden_states = recompute( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) - # hidden_states = self.sd3_transformer(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb) - # print(encoder_hidden_states) - # encoder_hidden_states = None - # print((hidden_states - hidden_states_my)) - # print(paddle.max(paddle.abs(hidden_states - hidden_states_my))) - # exit() + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) @@ -398,8 +394,6 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - # state_dict["simplified_sd3.linear1.weight"] = paddle.assign(state_dict["transformer_blocks.0.norm1.linear.weight"]) - # state_dict["simplified_sd3.linear1.bias"] = 
paddle.assign(state_dict["transformer_blocks.0.norm1.linear.bias"]) for i in range(24): base_map_sd3 = [ (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), @@ -447,10 +441,6 @@ def custom_modify_weight(cls, state_dict): else: print(f"Warning!!: '{from_}' not found in state_dict") - # state_dict["simplified_sd3.linear1.weight"] = paddle.concat([state_dict["simplified_sd3.linear1.weight"], state_dict[f"transformer_blocks.{i}.norm1.linear.weight"]], axis=1) - # state_dict["simplified_sd3.linear1.bias"] = paddle.concat([state_dict["simplified_sd3.linear1.bias"],state_dict[f"transformer_blocks.{i}.norm1.linear.bias"]], axis=0) - # print("old_weight",state_dict[f"simplified_sd3.linear1.{i}.weight"]) - # print("old_bias",state_dict[f"simplified_sd3.linear1.{i}.bias"]) state_dict[f"simplified_sd3.qkv.{i}.weight"] = paddle.assign( paddle.concat( [ From 357b75a98276c97129ebd50c124c107877f2487e Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 21 Aug 2024 08:02:00 +0000 Subject: [PATCH 12/65] update transformer_sd3 --- ..._to_image_generation-stable_diffusion_3.py | 2 +- .../ppdiffusers/models/transformer_sd3.py | 62 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index b1e266223..107d132e1 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -41,7 +41,7 @@ def parse_args(): parser.add_argument( "--inference_optimize_triton", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), - default=True, + default=False, help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.", ) parser.add_argument( diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index eda706fba..f1a63310a 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -18,6 +18,7 @@ import paddle import paddle.nn as nn +from paddle.distributed.fleet.utils import recompute from ..configuration_utils import ConfigMixin, register_to_config @@ -26,19 +27,18 @@ from ..models.attention_processor import Attention, AttentionProcessor from ..models.modeling_utils import ModelMixin from ..models.normalization import AdaLayerNormContinuous -from ..utils import ( # recompute_use_reentrant,; use_old_recompute, +from ..utils import ( USE_PEFT_BACKEND, logging, + recompute_use_reentrant, scale_lora_layers, unscale_lora_layers, + use_old_recompute, ) from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from .simplified_sd3 import SimplifiedSD3 from .transformer_2d import Transformer2DModelOutput -# from paddle.distributed.fleet.utils import recompute - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -269,9 +269,29 @@ def sd3_origin_transformer( temb, ): for block in self.transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) + if self.training and self.gradient_checkpointing and not use_old_recompute(): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward 
+ + ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + hidden_states = recompute( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) return encoder_hidden_states, hidden_states def forward( @@ -341,31 +361,9 @@ def forward( encoder_hidden_states = None else: - for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing and not use_old_recompute(): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - hidden_states = recompute( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) + encoder_hidden_states, hidden_states = self.sd3_origin_transformer( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From f54bf840f224fd98e26464af847741ebac5a88cd Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 21 Aug 2024 08:11:07 +0000 Subject: [PATCH 13/65] update transformer_sd3 --- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index f1a63310a..109285121 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -326,10 +326,6 @@ def forward( `tuple` where the first element is the sample tensor. 
""" - # print(str(self)) - # with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/AibinSD3/state_dict_817.txt", "a") as time_file: - # time_file.write(str(self.state_dict().keys())) - if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) From 3245b2fb9f292d15c613933b5baa11b78b50e565 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 21 Aug 2024 13:57:02 +0000 Subject: [PATCH 14/65] update triton & simplified_sd3 --- paddlemix/triton_ops/triton_ops.py | 4 +- .../ppdiffusers/models/simplified_sd3.py | 75 ++++++++++--------- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index 58c2b7ef0..bb6fead29 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -1072,7 +1072,7 @@ def modulate(x, shift, scale): M = x.shape[0] * x.shape[1] N = x.shape[2] seq_size = x.shape[1] - BLOCK_SIZE = min(1024, triton.next_power_of_2(N)) + BLOCK_SIZE = 2048#min(1024, triton.next_power_of_2(N)) op_name = "triton_adaptive_layer_norm" op_name += get_dtype_str(x.dtype) @@ -1683,7 +1683,7 @@ def split_concat(x, y): seq_eqkv = y.shape[1] ouput_hidden = hidd_x // 3 BLOCK_SIZE = triton.next_power_of_2(ouput_hidden) - op_name = "triton_split_concat" + op_name = "split_concat" op_name += get_dtype_str(x.dtype) op_name += f"_{BLOCK_SIZE}" diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index ec7a02c45..ee0c8f256 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -37,46 +37,46 @@ def __init__( # self.silu1 = nn.LayerList([nn.Silu() for i in range(num_layers)]) self.silu = nn.Silu() - self.linear1 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers)]) # 1536,9216 - # self.linear1 = nn.Linear(1536, 6 * 1536 * 24) + self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) # 1536,9216 + # self.linear1 = nn.Linear(self.dim, 6 * self.dim * 24) self.norm1 = nn.LayerList( - [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] + [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] ) # self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)]) - self.linear_context01 = nn.LayerList([nn.Linear(1536, 6 * 1536) for i in range(num_layers - 1)]) # 1536,9216 + self.linear_context01 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers - 1)]) # 1536,9216 self.norm1_context01 = nn.LayerList( - [nn.LayerNorm(1536, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] + [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] ) # self.silu2_context0 = nn.Silu() - self.linear_context0 = nn.Linear(1536, 1536 * 2, bias_attr=self.bias) - self.norm1_context0 = nn.LayerNorm(1536, epsilon=1e-06, weight_attr=False, bias_attr=self.bias) - self.q = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.k = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.v = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.qkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)]) - self.eq = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.ek = 
nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.ev = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.eqkv = nn.LayerList([nn.Linear(1536, 1536 * 3) for i in range(num_layers)]) - self.to_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers)]) - self.to_add_out_linear = nn.LayerList([nn.Linear(1536, 1536) for i in range(num_layers - 1)]) + self.linear_context0 = nn.Linear(self.dim, self.dim * 2, bias_attr=self.bias) + self.norm1_context0 = nn.LayerNorm(self.dim, epsilon=1e-06, weight_attr=False, bias_attr=self.bias) + self.q = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.k = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.v = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) + self.eq = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.ek = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.ev = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) + self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)]) self.ffn_norm = nn.LayerList( - [nn.LayerNorm(1536, weight_attr=False, bias_attr=False, epsilon=1e-6) for i in range(num_layers)] + [nn.LayerNorm(self.dim, weight_attr=False, bias_attr=False, epsilon=1e-6) for i in range(num_layers)] ) - self.ffn1 = nn.LayerList([nn.Linear(1536, 1536 * 4) for i in range(num_layers)]) - self.ffn2 = nn.LayerList([nn.Linear(1536 * 4, 1536) for i in range(num_layers)]) + self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)]) + self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)]) self.ffn_context_norm = nn.LayerList( - [nn.LayerNorm(1536, epsilon=1e-6, weight_attr=False, bias_attr=False) for i in range(num_layers - 1)] + [nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=False) for i in range(num_layers - 1)] ) - self.ffn_context1 = nn.LayerList([nn.Linear(1536, 1536 * 4) for i in range(num_layers - 1)]) - self.ffn_context2 = nn.LayerList([nn.Linear(1536 * 4, 1536) for i in range(num_layers - 1)]) + self.ffn_context1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)]) + self.ffn_context2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)]) def forward(self, hidden_states, encoder_hidden_states, temb): - + print("--------------------this is simplified_sd3------------------------") temb_silu = self.silu(temb) # emb1 = self.linear1(temb_silu) for i in range(self.num_layers): - # emb=emb1[:,i*6*1536:(i+1)*1536*6] + # emb=emb1[:,i*6*self.dim:(i+1)*self.dim*6] context_pre_only = i == self.num_layers - 1 emb = self.linear1[i](temb_silu) @@ -133,23 +133,30 @@ def forward(self, hidden_states, encoder_hidden_states, temb): v = v.reshape([2, -1, 24, 64]) else: # residual = norm_hidden_states - q = self.q[i](norm_hidden_states) - k = self.k[i](norm_hidden_states) - v = self.v[i](norm_hidden_states) + # q = self.q[i](norm_hidden_states) + # k = self.k[i](norm_hidden_states) + # v = self.v[i](norm_hidden_states) + qkv = self.qkv[i](norm_hidden_states) + q,k,v = 
paddle.split(qkv,axis=2, num_or_sections=3) - eq = self.eq[i](norm_encoder_hidden_states) - ek = self.ek[i](norm_encoder_hidden_states) - ev = self.ev[i](norm_encoder_hidden_states) + # eq = self.eq[i](norm_encoder_hidden_states) + # ek = self.ek[i](norm_encoder_hidden_states) + # ev = self.ev[i](norm_encoder_hidden_states) + eqkv = self.eqkv[i](norm_encoder_hidden_states) + eq,ek,ev = paddle.split(eqkv,axis=2, num_or_sections=3) + + q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64]) k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64]) v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64]) - + print(q.shape, k.shape, v.shape) + exit(0) norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) - norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, 1536]) + norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim]) norm_hidden_states1 = norm_hidden_states1.astype(q.dtype) - attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[1024, 154], axis=1) + attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[hidden_states.shape[1] , encoder_hidden_states.shape[1]], axis=1) attn_output = paddle.nn.functional.linear( attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias From 5516df672d0e5545f42324b852e191e26761ec65 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 22 Aug 2024 03:07:47 +0000 Subject: [PATCH 15/65] update simplified_sd3 --- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index ee0c8f256..0df913ea7 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -35,19 +35,16 @@ def __init__( self.bias = True norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) - # self.silu1 = nn.LayerList([nn.Silu() for i in range(num_layers)]) self.silu = nn.Silu() - self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) # 1536,9216 + self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) # self.linear1 = nn.Linear(self.dim, 6 * self.dim * 24) self.norm1 = nn.LayerList( [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] ) - # self.silu2_context01 = nn.LayerList([nn.Silu() for i in range(num_layers - 1)]) - self.linear_context01 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers - 1)]) # 1536,9216 + self.linear_context01 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers - 1)]) self.norm1_context01 = nn.LayerList( [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] ) - # self.silu2_context0 = nn.Silu() self.linear_context0 = nn.Linear(self.dim, self.dim * 2, bias_attr=self.bias) self.norm1_context0 = nn.LayerNorm(self.dim, epsilon=1e-06, weight_attr=False, bias_attr=self.bias) self.q = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) @@ -150,8 +147,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb): q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64]) k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64]) v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64]) - print(q.shape, k.shape, v.shape) - exit(0) + 
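The fused projection that this hunk cleans up rests on a simple identity: one GEMM against weights concatenated along the output axis equals three separate GEMMs followed by a split. A minimal standalone sketch of that equivalence (not part of the patch; the batch and sequence sizes are made up, the width matches SD3-medium's 1536):

    import paddle

    dim = 1536
    x = paddle.randn([2, 8, dim])

    q_proj = paddle.nn.Linear(dim, dim)
    k_proj = paddle.nn.Linear(dim, dim)
    v_proj = paddle.nn.Linear(dim, dim)

    # paddle.nn.Linear stores weight as [in_features, out_features], so the fused
    # parameters are the per-projection ones concatenated on the output axis.
    qkv = paddle.nn.Linear(dim, 3 * dim)
    qkv.weight.set_value(paddle.concat([q_proj.weight, k_proj.weight, v_proj.weight], axis=1))
    qkv.bias.set_value(paddle.concat([q_proj.bias, k_proj.bias, v_proj.bias], axis=0))

    q, k, v = paddle.split(qkv(x), num_or_sections=3, axis=2)
    assert paddle.allclose(q, q_proj(x), atol=1e-4).item()
    assert paddle.allclose(k, k_proj(x), atol=1e-4).item()
    assert paddle.allclose(v, v_proj(x), atol=1e-4).item()

This identity is also why the q/k/v bias concatenation corrected in a later patch matters: concatenating the q bias three times would only reproduce the q projection.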
norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim]) norm_hidden_states1 = norm_hidden_states1.astype(q.dtype) From 874d5d726589a57085633456145168aa34d9f7ee Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 22 Aug 2024 03:36:03 +0000 Subject: [PATCH 16/65] update simplified_sd3 --- .../ppdiffusers/models/simplified_sd3.py | 22 ++++++++----------- .../ppdiffusers/models/transformer_sd3.py | 9 ++------ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 0df913ea7..4f63c544e 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -34,6 +34,7 @@ def __init__( self.dim = dim self.bias = True norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) + context_norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=self.bias) self.silu = nn.Silu() self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) @@ -41,12 +42,10 @@ def __init__( self.norm1 = nn.LayerList( [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] ) - self.linear_context01 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers - 1)]) - self.norm1_context01 = nn.LayerList( - [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers - 1)] + self.linear_context = nn.LayerList([nn.Linear(self.dim, (6 if i Date: Thu, 22 Aug 2024 03:42:17 +0000 Subject: [PATCH 17/65] delete context_pre_only=False --- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 2 +- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 4f63c544e..5243dcaaf 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -25,7 +25,7 @@ class SimplifiedSD3(nn.Layer): def __init__( - self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int, context_pre_only=False + self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int ): super().__init__() diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 1a31a2b20..0e1a17b61 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -116,7 +116,6 @@ def __init__( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.inner_dim, - # context_pre_onl, ) self.simplified_sd3 = paddle.incubate.jit.inference( self.simplified_sd3, From 18777b62ca7a4836dd5f9cd3b5a14e7215cf18a9 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 22 Aug 2024 06:47:34 +0000 Subject: [PATCH 18/65] modify triton_optimize --- paddlemix/triton_ops/triton_ops.py | 7 ++++ .../ppdiffusers/models/simplified_sd3.py | 40 +++---------------- .../ppdiffusers/models/transformer_sd3.py | 4 +- 3 files changed, 14 insertions(+), 37 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index bb6fead29..be22a9f01 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -1074,6 +1074,13 @@ def 
modulate(x, shift, scale): seq_size = x.shape[1] BLOCK_SIZE = 2048#min(1024, triton.next_power_of_2(N)) + + # baseline. + if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None : + norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon) + norm_hidden_states = norm_hidden_states * (1 + scale[:, None]) + shift[:, None] + return norm_hidden_states + op_name = "triton_adaptive_layer_norm" op_name += get_dtype_str(x.dtype) op_name += f"_{BLOCK_SIZE}_{weight_attr}_{bias_attr}" diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 5243dcaaf..6b761796b 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -28,8 +28,6 @@ def __init__( self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int ): super().__init__() - - self.context_pre_only = context_pre_only self.num_layers = num_layers self.dim = dim self.bias = True @@ -76,50 +74,23 @@ def forward(self, hidden_states, encoder_hidden_states, temb): context_pre_only = i == self.num_layers - 1 emb = self.linear1[i](temb_silu) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - - norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( - hidden_states, scale_msa, shift_msa, epsilon=1e-06 - ) - else: - norm_hidden_states = self.norm1[i](hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + import paddlemix + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa, epsilon=1e-06) emb = self.linear_context[i](temb_silu) if not context_pre_only: shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1) - - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( - encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06 - ) - else: - norm_encoder_hidden_states = ( - self.norm1_context[i](encoder_hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] - ) - + norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06) else: scale, shift = paddle.chunk(emb, 2, axis=1) - - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( - encoder_hidden_states, scale, shift, bias=self.norm1_context[i].bias - ) - else: - norm_encoder_hidden_states = ( - self.norm1_context[i](encoder_hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - ) + norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(encoder_hidden_states, scale, shift, bias=self.norm1_context[i].bias) if os.getenv("INFERENCE_OPTIMIZE_TRITON"): qkv = self.qkv[i](norm_hidden_states) eqkv = self.eqkv[i](norm_encoder_hidden_states) import paddlemix - q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv) q = q.reshape([2, -1, 24, 64]) k = k.reshape([2, -1, 24, 64]) @@ -159,7 +130,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb): if os.getenv("INFERENCE_OPTIMIZE_TRITON"): import paddlemix - hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 ) @@ -179,7 +149,6 @@ def forward(self, hidden_states, encoder_hidden_states, temb): if not context_pre_only: if 
os.getenv("INFERENCE_OPTIMIZE_TRITON"): import paddlemix - ( encoder_hidden_states, norm_encoder_hidden_states, @@ -201,4 +170,5 @@ def forward(self, hidden_states, encoder_hidden_states, temb): encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output else: encoder_hidden_states = None + return encoder_hidden_states, hidden_states diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 0e1a17b61..c791f0c59 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -443,8 +443,8 @@ def custom_modify_weight(cls, state_dict): paddle.concat( [ state_dict[f"simplified_sd3.q.{i}.bias"], - state_dict[f"simplified_sd3.q.{i}.bias"], - state_dict[f"simplified_sd3.q.{i}.bias"], + state_dict[f"simplified_sd3.k.{i}.bias"], + state_dict[f"simplified_sd3.v.{i}.bias"], ], axis=0, ) From 7a288e4adaef38489aa2c9a75088c1f95f263848 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 22 Aug 2024 07:30:36 +0000 Subject: [PATCH 19/65] modify triton_optimize --- paddlemix/triton_ops/triton_ops.py | 12 ++- .../ppdiffusers/models/simplified_sd3.py | 84 +++++++++---------- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index be22a9f01..d75385050 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -839,6 +839,13 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil seq_size = x.shape[1] N_npo2 = triton.next_power_of_2(N) + # baseline. + if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None: + resi_out_paddle = mha_out * gate_msa.unsqueeze(axis=1) + x + norm_hidden_states = paddle.nn.functional.layer_norm(resi_out_paddle, [N], weight, bias, epsilon) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return resi_out_paddle, norm_hidden_states + op_name = "triton_fused_adaLN_scale_residual" op_name += get_dtype_str(x.dtype) op_name += f"_{N_npo2}_{weight_attr}_{bias_attr}" @@ -1072,11 +1079,10 @@ def modulate(x, shift, scale): M = x.shape[0] * x.shape[1] N = x.shape[2] seq_size = x.shape[1] - BLOCK_SIZE = 2048#min(1024, triton.next_power_of_2(N)) + BLOCK_SIZE = 2048 # min(1024, triton.next_power_of_2(N)) - # baseline. 
- if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None : + if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None: norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon) norm_hidden_states = norm_hidden_states * (1 + scale[:, None]) + shift[:, None] return norm_hidden_states diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 6b761796b..9a7bee143 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -24,9 +24,7 @@ class SimplifiedSD3(nn.Layer): - def __init__( - self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int - ): + def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers self.dim = dim @@ -36,13 +34,25 @@ def __init__( self.silu = nn.Silu() self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) - # self.linear1 = nn.Linear(self.dim, 6 * self.dim * 24) self.norm1 = nn.LayerList( [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] ) - self.linear_context = nn.LayerList([nn.Linear(self.dim, (6 if i Date: Thu, 22 Aug 2024 08:47:26 +0000 Subject: [PATCH 20/65] modify triton_optimize --- .../ppdiffusers/models/simplified_sd3.py | 123 ++++++------------ .../ppdiffusers/models/transformer_sd3.py | 8 +- 2 files changed, 46 insertions(+), 85 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 9a7bee143..ebc999b2c 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -12,82 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. 
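The `(6 if i < num_layers - 1 else 2) * self.dim` width in the rewrite above mirrors the original blocks: every joint block consumes six modulation vectors per stream (shift/scale/gate for attention and again for the FFN), while the final block's context stream is only normalized, so it needs just a scale and a shift. A small sketch of how those chunks are consumed (reduced dimensions, illustrative only):

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    dim, num_layers = 64, 4  # SD3-medium uses 1536 and 24
    temb = paddle.randn([2, dim])
    linear_context = nn.LayerList(
        [nn.Linear(dim, (6 if i < num_layers - 1 else 2) * dim) for i in range(num_layers)]
    )

    for i in range(num_layers):
        emb = linear_context[i](F.silu(temb))
        if i < num_layers - 1:
            # joint block: modulation for both the attention and the FFN sub-layers
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
        else:
            # last block: the context stream only gets a final scale/shift (ada_norm_continous)
            scale, shift = emb.chunk(2, axis=1)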
-# import math -import os - import paddle import paddle.nn.functional as F from paddle import nn -from paddle.incubate.nn.functional import fused_linear, fused_linear_activation - -optimize = True - class SimplifiedSD3(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers self.dim = dim - self.bias = True - norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) - context_norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=self.bias) self.silu = nn.Silu() self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) - self.norm1 = nn.LayerList( - [nn.LayerNorm(self.dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) for i in range(num_layers)] - ) self.linear_context = nn.LayerList( [nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)] ) - self.norm1_context = nn.LayerList( - [ - nn.LayerNorm( - self.dim, - epsilon=1e-6, - **( - norm_elementwise_affine_kwargs - if i < num_layers - 1 - else context_norm_elementwise_affine_kwargs - ), - ) - for i in range(num_layers) - ] - ) - self.q = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) - self.k = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) - self.v = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + + self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True) + self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) - self.eq = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) - self.ek = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) - self.ev = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)]) - self.ffn_norm = nn.LayerList( - [nn.LayerNorm(self.dim, weight_attr=False, bias_attr=False, epsilon=1e-6) for i in range(num_layers)] - ) self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)]) self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)]) - self.ffn_context_norm = nn.LayerList( - [nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=False) for i in range(num_layers - 1)] - ) - self.ffn_context1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)]) - self.ffn_context2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)]) + self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)]) + self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)]) def forward(self, hidden_states, encoder_hidden_states, temb): print("--------------------this is simplified_sd3------------------------") temb_silu = self.silu(temb) + + last_ffn_output = None + last_hidden_states = None + last_gate_mlp = None + for i in range(self.num_layers): context_pre_only = i == self.num_layers - 1 emb = self.linear1[i](temb_silu) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) + import paddlemix - norm_hidden_states = 
paddlemix.triton_ops.adaptive_layer_norm( - hidden_states, scale_msa, shift_msa, epsilon=1e-06 - ) + if last_ffn_output is None: + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( + hidden_states, scale_msa, shift_msa, epsilon=1e-06 + ) + else: + hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + last_hidden_states, last_ffn_output, last_gate_mlp, scale_msa, shift_msa, epsilon=1e-06 + ) emb = self.linear_context[i](temb_silu) if not context_pre_only: @@ -98,36 +73,15 @@ def forward(self, hidden_states, encoder_hidden_states, temb): else: scale, shift = paddle.chunk(emb, 2, axis=1) norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( - encoder_hidden_states, scale, shift, bias=self.norm1_context[i].bias + encoder_hidden_states, scale, shift, bias=self.norm_last_context.bias ) - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - qkv = self.qkv[i](norm_hidden_states) - eqkv = self.eqkv[i](norm_encoder_hidden_states) - - import paddlemix - - q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv) - q = q.reshape([2, -1, 24, 64]) - k = k.reshape([2, -1, 24, 64]) - v = v.reshape([2, -1, 24, 64]) - else: - # residual = norm_hidden_states - # q = self.q[i](norm_hidden_states) - # k = self.k[i](norm_hidden_states) - # v = self.v[i](norm_hidden_states) - qkv = self.qkv[i](norm_hidden_states) - q, k, v = paddle.split(qkv, axis=2, num_or_sections=3) - - # eq = self.eq[i](norm_encoder_hidden_states) - # ek = self.ek[i](norm_encoder_hidden_states) - # ev = self.ev[i](norm_encoder_hidden_states) - eqkv = self.eqkv[i](norm_encoder_hidden_states) - eq, ek, ev = paddle.split(eqkv, axis=2, num_or_sections=3) - - q = paddle.concat([q, eq], axis=1).reshape([2, -1, 24, 64]) - k = paddle.concat([k, ek], axis=1).reshape([2, -1, 24, 64]) - v = paddle.concat([v, ev], axis=1).reshape([2, -1, 24, 64]) + qkv = self.qkv[i](norm_hidden_states) + eqkv = self.eqkv[i](norm_encoder_hidden_states) + q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv) + q = q.reshape([2, -1, 24, 64]) + k = k.reshape([2, -1, 24, 64]) + v = v.reshape([2, -1, 24, 64]) norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim]) @@ -147,23 +101,30 @@ def forward(self, hidden_states, encoder_hidden_states, temb): hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 ) - ff_output = self.ffn1[i](norm_hidden_states) - ff_output = F.gelu(ff_output, approximate=True) - ff_output = self.ffn2[i](ff_output) + # ffn1 + ffn_output = self.ffn1[i](norm_hidden_states) + ffn_output = F.gelu(ffn_output, approximate=True) + ffn_output = self.ffn2[i](ffn_output) - ff_output = gate_mlp.unsqueeze(1) * ff_output - hidden_states = hidden_states + ff_output + if context_pre_only: + ffn_output = gate_mlp.unsqueeze(1) * ffn_output + hidden_states = hidden_states + ffn_output + else: + last_ffn_output = ffn_output + last_hidden_states = hidden_states + last_gate_mlp = gate_mlp + # ffn2 if not context_pre_only: (encoder_hidden_states, norm_encoder_hidden_states,) = paddlemix.triton_ops.fused_adaLN_scale_residual( encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 ) - context_ff_output = self.ffn_context1[i](norm_encoder_hidden_states) - context_ff_output = F.gelu(context_ff_output, approximate=True) - context_ff_output = self.ffn_context2[i](context_ff_output) + context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states) + 
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
+            context_ffn_output = self.ffn2_context[i](context_ffn_output)
 
-            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ffn_output
         else:
             encoder_hidden_states = None
 
diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index c791f0c59..40e188d87 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -416,10 +416,10 @@ def custom_modify_weight(cls, state_dict):
             extra_map_sd3 = [
                 (f"to_add_out_linear.{i}.weight", f"{i}.attn.to_add_out.weight"),
                 (f"to_add_out_linear.{i}.bias", f"{i}.attn.to_add_out.bias"),
-                (f"ffn_context1.{i}.weight", f"{i}.ff_context.net.0.proj.weight"),
-                (f"ffn_context1.{i}.bias", f"{i}.ff_context.net.0.proj.bias"),
-                (f"ffn_context2.{i}.weight", f"{i}.ff_context.net.2.weight"),
-                (f"ffn_context2.{i}.bias", f"{i}.ff_context.net.2.bias"),
+                (f"ffn1_context.{i}.weight", f"{i}.ff_context.net.0.proj.weight"),
+                (f"ffn1_context.{i}.bias", f"{i}.ff_context.net.0.proj.bias"),
+                (f"ffn2_context.{i}.weight", f"{i}.ff_context.net.2.weight"),
+                (f"ffn2_context.{i}.bias", f"{i}.ff_context.net.2.bias"),
             ]
             map_sd3 = base_map_sd3 + extra_map_sd3
 
From 95c9e47342884c2ac30e09e33ce43cf5e278f87a Mon Sep 17 00:00:00 2001
From: changwenbin <changwenbin@baidu.com>
Date: Thu, 22 Aug 2024 12:19:11 +0000
Subject: [PATCH 21/65] modify triton_fuse & fix performance issues caused by
 CUDA synchronization

---
 .../ppdiffusers/models/simplified_sd3.py      | 42 +++++++++++++++----
 .../scheduling_flow_match_euler_discrete.py   | 12 +++---
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index ebc999b2c..fbdfc28a1 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -16,6 +16,7 @@
 import paddle.nn.functional as F
 from paddle import nn
 
+
 class SimplifiedSD3(nn.Layer):
     def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
         super().__init__()
@@ -27,9 +28,9 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.linear_context = nn.LayerList(
             [nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)]
         )
-
+
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
-
+
         self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
         self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
         self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
@@ -47,6 +48,10 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
         last_hidden_states = None
         last_gate_mlp = None
 
+        last_context_ffn_output = None
+        last_context_hidden_states = None
+        last_context_gate_mlp = None
+
         for i in range(self.num_layers):
             context_pre_only = i == self.num_layers - 1
 
@@ -67,13 +72,32 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             emb = self.linear_context[i](temb_silu)
             if not context_pre_only:
                 shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1)
-                norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
-                    encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06
-                )
+                if last_context_ffn_output is None:
+                    norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
+                        encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06
+                    )
+                else:
+                    (
+                        encoder_hidden_states,
+                        norm_encoder_hidden_states,
+                    ) = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                        last_context_hidden_states,
+                        last_context_ffn_output,
+                        last_context_gate_mlp,
+                        scale_msa,
+                        shift_msa,
+                        epsilon=1e-06,
+                    )
             else:
+                # the last layer.
                 scale, shift = paddle.chunk(emb, 2, axis=1)
-                norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
-                    encoder_hidden_states, scale, shift, bias=self.norm_last_context.bias
+                (encoder_hidden_states, norm_encoder_hidden_states,) = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                    last_context_hidden_states,
+                    last_context_ffn_output,
+                    last_context_gate_mlp,
+                    scale,
+                    shift,
+                    epsilon=1e-06,
                 )
 
             qkv = self.qkv[i](norm_hidden_states)
@@ -124,7 +148,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                 context_ffn_output = F.gelu(context_ffn_output, approximate=True)
                 context_ffn_output = self.ffn2_context[i](context_ffn_output)
 
-                encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ffn_output
+                last_context_ffn_output = context_ffn_output
+                last_context_hidden_states = encoder_hidden_states
+                last_context_gate_mlp = c_gate_mlp
             else:
                 encoder_hidden_states = None
 
diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py
index ec71c1f51..218d6c133 100644
--- a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -24,7 +24,6 @@
 from ..utils.paddle_utils import randn_tensor
 from .scheduling_utils import SchedulerMixin
 
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -246,11 +245,12 @@ def step(
 
         sigma = self.sigmas[self.step_index]
 
-        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+        if s_churn == 0.0:
+            gamma = 0.0
+        else:
+            gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, generator=generator
-        )
+        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, generator=generator)
 
         eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
@@ -283,4 +283,4 @@ def step(
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
     def __len__(self):
-        return self.config.num_train_timesteps
\ No newline at end of file
+        return self.config.num_train_timesteps

From 84a9e7a3c02f3b8886edbfdc83a04a37dcba23f9 Mon Sep 17 00:00:00 2001
From: changwenbin <changwenbin@baidu.com>
Date: Fri, 23 Aug 2024 07:59:00 +0000
Subject: [PATCH 22/65] modify transformer_sd3 for inference_optimize_origin

---
 ppdiffusers/ppdiffusers/models/transformer_sd3.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 40e188d87..30f85d0e6 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -1,4 +1,3 @@
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 # Copyright 2023 The HuggingFace Team. All rights reserved.
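The carried `last_context_*` (and `last_*`) tensors in the patch above implement a one-iteration deferral: rather than materializing `x + gate * ffn_out` at the end of layer i and re-normalizing at the top of layer i+1, both steps collapse into a single fused_adaLN_scale_residual call. A runnable sketch of that control flow with eager stand-ins for the Triton kernel (shapes and the `block` closure are illustrative):

    import paddle
    import paddle.nn.functional as F

    def ada_ln(x, scale, shift, epsilon=1e-06):
        norm = F.layer_norm(x, x.shape[-1:], epsilon=epsilon)
        return norm * (1 + scale[:, None]) + shift[:, None]

    def fused_adaln_residual(x, ffn_out, gate, scale, shift):
        resi = ffn_out * gate.unsqueeze(1) + x  # eager stand-in for the fused kernel
        return resi, ada_ln(resi, scale, shift)

    dim, num_layers = 64, 4
    hidden = paddle.randn([2, 16, dim])
    scales, shifts, gates = [[paddle.randn([2, dim]) for _ in range(num_layers)] for _ in range(3)]
    block = lambda x: x * 0.5  # stand-in for attention + FFN of one layer

    last_hidden = last_ffn = last_gate = None
    for i in range(num_layers):
        if last_ffn is None:
            norm = ada_ln(hidden, scales[i], shifts[i])
        else:
            # layer i-1's epilogue fuses with layer i's prologue
            hidden, norm = fused_adaln_residual(last_hidden, last_ffn, last_gate, scales[i], shifts[i])
        last_hidden, last_ffn, last_gate = hidden, block(norm), gates[i]

    hidden = last_hidden + last_gate.unsqueeze(1) * last_ffn  # epilogue of the final layer

The scheduler change in the same patch serves the synchronization half of the subject: short-circuiting gamma when s_churn is zero skips the `s_tmin <= sigma <= s_tmax` comparison on a GPU tensor, which would otherwise appear to force a device-to-host read every step.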
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -344,9 +343,14 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) if self.inference_optimize: - hidden_states = self.simplified_sd3( + out = self.simplified_sd3( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) + # this is for paddle inference. + if isinstance(out, paddle.Tensor): + hidden_states = out + else: + hidden_states = out[1] encoder_hidden_states = None elif self.inference_optimize_origin: From 9dd918df35bbcc026ef7f39bfdef6653e66d479a Mon Sep 17 00:00:00 2001 From: changwenbin Date: Fri, 23 Aug 2024 12:28:53 +0000 Subject: [PATCH 23/65] update vae triton_split --- paddlemix/triton_ops/__init__.py | 2 + paddlemix/triton_ops/triton_ops.py | 134 ++++++++++++++++++ ..._to_image_generation-stable_diffusion_3.py | 12 +- .../ppdiffusers/models/simplified_sd3.py | 7 +- .../pipeline_stable_diffusion_3.py | 34 +++-- 5 files changed, 171 insertions(+), 18 deletions(-) diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py index 4c2f1691d..f10d9daaf 100644 --- a/paddlemix/triton_ops/__init__.py +++ b/paddlemix/triton_ops/__init__.py @@ -22,6 +22,7 @@ paddle_use_triton, rms_norm, split_concat, + triton_split, weight_only_int8, ) from .triton_utils import ( @@ -41,6 +42,7 @@ "get_dtype_str", "fused_rotary_emb", "split_concat", + "triton_split", ] except: pass diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index d75385050..070d8302b 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -1735,3 +1735,137 @@ def split_concat(x, y): outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2}, ) return out0, out1, out2 + + +########################### triton split ############################### +triton_split_template = ( + """ +std::vector ${op_name}_func( + const paddle::Tensor &x, + const std::vector num_or_sections, + const int64_t axis) { + + int output_batch = x.dims()[0]; + int output_seq0 = num_or_sections[0]; + int output_seq1 = num_or_sections[1]; + int output_hidden = x.dims()[2]; + + auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place()); + auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place()); + + auto out0 = get_tensor_ptr(out0_tensor); + auto out1 = get_tensor_ptr(out1_tensor); + + auto input = get_tensor_ptr(x); + + auto run_stream = out0_tensor.stream(); + +""" + + tune_and_invoke_part + + """ + return {out0_tensor, out1_tensor}; +} + +std::vector> ${op_name}_InferShape( + const std::vector& A_shape) { + + std::vector out_shape0 = {A_shape[0], 1024, A_shape[2]}; + std::vector out_shape1 = {A_shape[0], 154, A_shape[2]}; + + return {out_shape0, out_shape1}; +} + +std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { + return {A_dtype, A_dtype}; +} + +PD_BUILD_OP(${op_name}) + .Inputs({"x"}) + .Outputs({"out0_tensor", "out1_tensor"}) + .SetKernelFn(PD_KERNEL(${op_name}_func)) + .Attrs({"num_or_sections: std::vector", "axis: int64_t"}) + .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype)) + .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape)); +""" +) + + +@paddle_use_triton( + custom_op_template=triton_split_template, + key=["1"], +) +def triton_split_kernel( + out0, + out1, + input, + output_seq0, + output_seq1, + output_batch, + output_hidden, + BLOCK_SIZE: tl.constexpr, +): + batch = tl.program_id(axis=0) + out_row = 
tl.program_id(axis=1) + read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input + + read_offsets = tl.arange(0, BLOCK_SIZE) + mask = read_offsets < output_hidden + read_data = tl.load(read_ptr + read_offsets, mask=mask) + + if out_row < output_seq0: + write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets + else: + write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets + + tl.store(write_ptr, read_data, mask=mask) + + +def triton_split(x, num_or_sections=[-1, -1], axis=1): + assert len(x.shape) == 3 + output_batch = x.shape[0] + output_seq0 = num_or_sections[0] + output_seq1 = num_or_sections[1] + output_hidden = x.shape[2] + + BLOCK_SIZE = triton.next_power_of_2(output_hidden) + op_name = "triton_split" + op_name += get_dtype_str(x.dtype) + op_name += f"_{BLOCK_SIZE}" + + if op_name not in OpProtoHolder.instance().op_proto_map.keys(): + out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype) + out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype) + grid = ("output_batch", "output_seq0+output_seq1") + + triton_split_kernel[(op_name, grid)]( + out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=2048 + ) + + if in_dynamic_or_pir_mode(): + print(f"== we are in dynamic mode, op_name: {op_name}") + outs = _C_ops._run_custom_op( + op_name, + x, + num_or_sections, + axis, + ) + return outs[0], outs[1] + else: + print(f"== we are in dynamic to static mode, op_name: {op_name}") + helper = LayerHelper(op_name, **locals()) + inputs = { + "x": x, + } + out0 = helper.create_variable_for_type_inference(dtype=x.dtype) + out1 = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type=op_name, + inputs=inputs, + attrs={ + "num_or_sections": num_or_sections, + "axis": axis, + }, + outputs={"out0_tensor": out0, "out1_tensor": out1}, + ) + return out0, out1 diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 107d132e1..a32dd9da0 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -72,6 +72,16 @@ def parse_args(): # from_hf_hub=True, # from_diffusers=True, ) + +# for vae model +pipe.vae.decode = paddle.incubate.jit.inference( + pipe.vae.decode, + save_model_dir="./tmp/vae_static_models", + cache_static_model=False, + enable_new_ir=True, + exp_enable_use_cutlass=True, +) + generator = paddle.Generator().manual_seed(42) prompt = "A cat holding a sign that says hello world" @@ -91,7 +101,7 @@ def parse_args(): generator=generator, ).images[0] - repeat_times = 5 + repeat_times = 6 sumtime = 0.0 for i in range(repeat_times): paddle.device.synchronize() diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index fbdfc28a1..3901b3f54 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -109,9 +109,12 @@ def forward(self, hidden_states, encoder_hidden_states, temb): norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim]) + # attn_output, context_attn_output = paddle.split( + # 
norm_hidden_states1, num_or_sections=[hidden_states.shape[1], encoder_hidden_states.shape[1]], axis=1 + # ) - attn_output, context_attn_output = paddle.split( - norm_hidden_states1, num_or_sections=[hidden_states.shape[1], encoder_hidden_states.shape[1]], axis=1 + attn_output, context_attn_output = paddlemix.triton_ops.triton_split( + norm_hidden_states1, num_or_sections=[1024, 154], axis=1 ) attn_output = paddle.nn.functional.linear( diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 732507924..65004692b 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -17,12 +17,12 @@ from typing import Any, Callable, Dict, List, Optional, Union import paddle -from ppdiffusers.transformers import ( + +from ppdiffusers.transformers import ( # T5TokenizerFast, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, - # T5TokenizerFast, - T5Tokenizer + T5Tokenizer, ) from ...image_processor import VaeImageProcessor @@ -30,15 +30,11 @@ from ...models.autoencoder_kl import AutoencoderKL from ...models.transformer_sd3 import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - logging, - replace_example_docstring, -) +from ...utils import logging, replace_example_docstring from ...utils.paddle_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import StableDiffusion3PipelineOutput - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -114,7 +110,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): # SD3LoraLoaderMixin +class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): # SD3LoraLoaderMixin r""" Args: @@ -224,7 +220,7 @@ def _get_t5_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - # breakpoint() + prompt_embeds = self.text_encoder_3(text_input_ids)[0] dtype = self.text_encoder_3.dtype @@ -383,7 +379,9 @@ def encode_prompt( ) clip_prompt_embeds = paddle.nn.functional.pad( - clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), data_format='NCL', + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + data_format="NCL", ) prompt_embeds = paddle.concat([clip_prompt_embeds, t5_prompt_embed], axis=-2) @@ -430,12 +428,14 @@ def encode_prompt( negative_clip_prompt_embeds = paddle.concat([negative_prompt_embed, negative_prompt_2_embed], axis=-1) t5_negative_prompt_embed = self._get_t5_prompt_embeds( - prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, ) negative_clip_prompt_embeds = paddle.nn.functional.pad( negative_clip_prompt_embeds, - (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), data_format='NCL', + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + data_format="NCL", ) negative_prompt_embeds = paddle.concat([negative_clip_prompt_embeds, t5_negative_prompt_embed], axis=-2) @@ -834,7 +834,11 @@ def __call__( else: latents = (latents / 
self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] + image_out = self.vae.decode(latents, return_dict=False) + if isinstance(image_out, paddle.Tensor): + image = image_out + else: + image = image_out[0] image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models @@ -843,4 +847,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file + return StableDiffusion3PipelineOutput(images=image) From 3a0b7e1a7d345d44741a7f11a885bf9153bc7a06 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 26 Aug 2024 11:50:41 +0000 Subject: [PATCH 24/65] vae T5 d2s & transformer forward d2s --- .../ppdiffusers/models/autoencoder_kl.py | 4 +++- ppdiffusers/ppdiffusers/models/embeddings.py | 20 ++++++++++------ .../ppdiffusers/models/transformer_sd3.py | 24 ++++++++++++------- .../pipeline_stable_diffusion_3.py | 20 ++++++++++++---- .../ppdiffusers/transformers/t5/modeling.py | 13 ++++++++++ 5 files changed, 61 insertions(+), 20 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py index d8fda325a..9f975ea22 100644 --- a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py +++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Dict, Optional, Tuple, Union import paddle @@ -88,6 +89,7 @@ def __init__( use_quant_conv: bool = True, use_post_quant_conv: bool = True, ): + os.environ["USE_PPXFORMERS"] = "False" super().__init__() # if down_block_out_channels not given, we will use block_out_channels _down_block_out_channels = block_out_channels if down_block_out_channels is None else down_block_out_channels @@ -116,7 +118,7 @@ def __init__( norm_num_groups=norm_num_groups, act_fn=act_fn, ) - + del os.environ["USE_PPXFORMERS"] self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) if use_post_quant_conv else None diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index 0446fcbd9..de67bd792 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -19,7 +19,7 @@ from paddle import nn from ..utils import USE_PEFT_BACKEND -from .activations import get_activation, FP32SiLU +from .activations import FP32SiLU, get_activation from .lora import LoRACompatibleLinear @@ -136,7 +136,7 @@ def __init__( interpolation_scale=1, add_pos_embed=True, data_format="NCHW", - pos_embed_max_size=None, # For SD3 cropping + pos_embed_max_size=None, # For SD3 cropping ): super().__init__() @@ -147,7 +147,12 @@ def __init__( self.data_format = data_format self.proj = nn.Conv2D( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias, data_format=data_format, + in_channels, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=patch_size, + bias_attr=bias, + data_format=data_format, ) if layer_norm: norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) @@ -178,6 +183,7 @@ def __init__( self.register_buffer( "pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), 
persistable=persistent ) + def cropped_pos_embed(self, height, width): """Crops positional embeddings for SD3 compatibility.""" if self.pos_embed_max_size is None: @@ -215,7 +221,7 @@ def forward(self, latent): if self.data_format == "NCHW": latent = latent.flatten(2).transpose([0, 2, 1]) # BCHW -> BNC else: - latent = latent.flatten(1, 2) # BHWC -> BNC + latent = latent.flatten(1, 2) # BHWC -> BNC if self.layer_norm: latent = self.norm(latent) @@ -521,7 +527,6 @@ def forward(self, image_embeds: paddle.Tensor): return image_embeds - class CombinedTimestepTextProjEmbeddings(nn.Layer): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() @@ -532,7 +537,7 @@ def __init__(self, embedding_dim, pooled_projection_dim): def forward(self, timestep, pooled_projection): timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype)) # (N, D) pooled_projections = self.text_embedder(pooled_projection) @@ -540,6 +545,7 @@ def forward(self, timestep, pooled_projection): return conditioning + class CombinedTimestepLabelEmbeddings(nn.Layer): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() @@ -906,4 +912,4 @@ def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 30f85d0e6..f1920311a 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -116,13 +116,13 @@ def __init__( num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.inner_dim, ) - self.simplified_sd3 = paddle.incubate.jit.inference( - self.simplified_sd3, - enable_new_ir=True, - cache_static_model=False, - exp_enable_use_cutlass=True, - delete_pass_lists=["add_norm_fuse_pass"], - ) + # self.simplified_sd3 = paddle.incubate.jit.inference( + # self.simplified_sd3, + # enable_new_ir=True, + # cache_static_model=False, + # exp_enable_use_cutlass=True, + # delete_pass_lists=["add_norm_fuse_pass"], + # ) if self.inference_optimize_origin: self.sd3_origin_transformer = paddle.incubate.jit.inference( self.sd3_origin_transformer, @@ -292,6 +292,13 @@ def custom_forward(*inputs): ) return encoder_hidden_states, hidden_states + @paddle.incubate.jit.inference( + enable_new_ir=True, + cache_static_model=False, + save_model_dir="./tmp/sd3", + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], + ) def forward( self, hidden_states: paddle.Tensor, @@ -375,7 +382,8 @@ def forward( hidden_states = hidden_states.reshape( shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) ) - hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) + # hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = paddle.transpose(hidden_states, [0, 5, 1, 3, 2, 4]) output = hidden_states.reshape( shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) ) diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 65004692b..5ad77a04f 
100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -110,7 +110,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): # SD3LoraLoaderMixin +class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): r""" Args: @@ -221,7 +221,12 @@ def _get_t5_prompt_embeds( f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_3(text_input_ids)[0] + outputs = self.text_encoder_3(text_input_ids) + # in order to d2s + if isinstance(outputs, paddle.Tensor): + prompt_embeds = outputs + else: + prompt_embeds = outputs[0] dtype = self.text_encoder_3.dtype prompt_embeds = prompt_embeds.astype(dtype=dtype) @@ -793,14 +798,19 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( + # in order to d2s + noise_pred_out = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, - )[0] + ) + if isinstance(noise_pred_out, paddle.Tensor): + noise_pred = noise_pred_out + else: + noise_pred = noise_pred_out[0] # perform guidance if self.do_classifier_free_guidance: @@ -834,6 +844,8 @@ def __call__( else: latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + # in order to d2s + latents = latents.cast("float32") image_out = self.vae.decode(latents, return_dict=False) if isinstance(image_out, paddle.Tensor): image = image_out diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index c3fd10cd3..e1e59a65d 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -1555,6 +1555,9 @@ def __init__(self, config: T5Config): # Initialize weights and apply final processing self.post_init() + # in order to d2s + del self.encoder + def get_input_embeddings(self): return self.shared @@ -1569,6 +1572,14 @@ def set_input_embeddings(self, new_embeddings): def get_encoder(self): return self.encoder + @paddle.incubate.jit.inference( + enable_new_ir=False, + cache_static_model=True, + save_model_dir="./tmp/T5", + with_trt=True, + trt_precision_mode="float16", + trt_use_static=True, + ) def forward( self, input_ids: Optional[paddle.Tensor] = None, @@ -1605,6 +1616,8 @@ def forward( return_dict=return_dict, ) + # there is a bug in dy2s + return encoder_output.last_hidden_state return encoder_output From 6d02d79d1a1704436cbd34ebf6af3d85e4e789a7 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 26 Aug 2024 11:57:30 +0000 Subject: [PATCH 25/65] update demo --- .../text_to_image_generation-stable_diffusion_3.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index a32dd9da0..a39ea77ed 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -77,9 +77,12 @@ def parse_args(): pipe.vae.decode = paddle.incubate.jit.inference( 
pipe.vae.decode, save_model_dir="./tmp/vae_static_models", - cache_static_model=False, - enable_new_ir=True, - exp_enable_use_cutlass=True, + cache_static_model=True, + enable_new_ir=False, + with_trt=True, + trt_precision_mode="float16", + trt_use_static=True, + collect_shape=False, ) generator = paddle.Generator().manual_seed(42) From 5d81b44743bdca13d8b95023a7f99a171996e86f Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 26 Aug 2024 16:10:54 +0000 Subject: [PATCH 26/65] update five model d2s --- ..._to_image_generation-stable_diffusion_3.py | 33 +++++++++++++++++-- .../pipeline_stable_diffusion_3.py | 17 ++++++++-- .../ppdiffusers/transformers/clip/modeling.py | 11 +++++-- .../ppdiffusers/transformers/t5/modeling.py | 27 +++++++-------- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index a39ea77ed..a0402a7c3 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -73,16 +73,45 @@ def parse_args(): # from_diffusers=True, ) +pipe.text_encoder = paddle.incubate.jit.inference( + pipe.text_encoder, + save_model_dir="./tmp/text_encoder", + cache_static_model=True, + with_trt=True, + trt_precision_mode="float16", + trt_use_static=True, +) + +pipe.text_encoder_2 = paddle.incubate.jit.inference( + pipe.text_encoder_2, + save_model_dir="./tmp/text_encoder_2", + cache_static_model=True, + with_trt=True, + trt_precision_mode="float16", + trt_use_static=True, +) + + + +pipe.text_encoder_3 = paddle.incubate.jit.inference( + pipe.text_encoder_3, + save_model_dir="./tmp/text_encoder_3_T5", + cache_static_model=True, + with_trt=True, + trt_precision_mode="float16", + trt_use_static=True, +) + + + # for vae model pipe.vae.decode = paddle.incubate.jit.inference( pipe.vae.decode, save_model_dir="./tmp/vae_static_models", cache_static_model=True, - enable_new_ir=False, with_trt=True, trt_precision_mode="float16", trt_use_static=True, - collect_shape=False, ) generator = paddle.Generator().manual_seed(42) diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 5ad77a04f..4c4c146e6 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -272,14 +272,25 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + # prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + # pooled_prompt_embeds = prompt_embeds[0] + + # if clip_skip is None: + # prompt_embeds = prompt_embeds.hidden_states[-2] + # else: + # prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + + prompt_embeds = text_encoder(text_input_ids) pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[1:][-2] else: - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + prompt_embeds = prompt_embeds[1:][-(clip_skip + 2)] + + pooled_prompt_embeds = 
pooled_prompt_embeds.astype(dtype=text_encoder.dtype) prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) _, seq_len, _ = prompt_embeds.shape diff --git a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py index 1029b1da1..b73078c82 100644 --- a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py @@ -1256,9 +1256,9 @@ def forward( input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + # output_attentions: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPTextModelOutput]: r""" Returns: @@ -1276,6 +1276,11 @@ def forward( >>> outputs = model(**inputs) >>> text_embeds = outputs.text_embeds ```""" + + output_attentions = None + output_hidden_states = True + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index e1e59a65d..db38b7a3b 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -40,6 +40,7 @@ Seq2SeqSequenceClassifierOutput, ) from paddlenlp.transformers.model_utils import register_base_model +from paddle.framework import in_dynamic_or_pir_mode from ...utils import logging from ..model_utils import ALL_LAYERNORM_LAYERS, PretrainedModel @@ -1572,22 +1573,11 @@ def set_input_embeddings(self, new_embeddings): def get_encoder(self): return self.encoder - @paddle.incubate.jit.inference( - enable_new_ir=False, - cache_static_model=True, - save_model_dir="./tmp/T5", - with_trt=True, - trt_precision_mode="float16", - trt_use_static=True, - ) def forward( self, input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, inputs_embeds: Optional[paddle.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple[paddle.Tensor], BaseModelOutput]: r""" Returns: @@ -1605,6 +1595,11 @@ def forward( >>> outputs = model(input_ids=input_ids) >>> last_hidden_states = outputs.last_hidden_state ```""" + + output_attentions = None + output_hidden_states = None + return_dict = None + return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_output = self.encoder( @@ -1615,10 +1610,12 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - # there is a bug in dy2s - return encoder_output.last_hidden_state - return encoder_output + + if in_dynamic_or_pir_mode(): + return encoder_output + else: + # there is a bug in dy2s,we fix it here. 
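All of the isinstance branches added in these patches guard the same behavior: a module or method wrapped by paddle.incubate.jit.inference returns plain Tensors (a bare Tensor for a single output, a list for several) instead of the ModelOutput the eager module returns. The pattern could be factored into one hypothetical helper (the patches deliberately inline the check at each call site):

    import paddle

    def first_output(out):
        # eager path: ModelOutput / tuple; jit.inference path: Tensor or list of Tensors
        return out if isinstance(out, paddle.Tensor) else out[0]

    # e.g. prompt_embeds = first_output(self.text_encoder_3(text_input_ids))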
+ return encoder_output.last_hidden_state class T5ForSequenceClassification(T5PretrainedModel): From 4bab1181a4380354af606c8fa3b9baafa9de2a80 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 27 Aug 2024 13:24:48 +0000 Subject: [PATCH 27/65] update SD3 clip T5 vae --- ..._to_image_generation-stable_diffusion_3.py | 6 +-- .../ppdiffusers/patches/paddle_patch.py | 2 +- .../pipeline_stable_diffusion_3.py | 37 ++++++++++--------- .../ppdiffusers/transformers/clip/modeling.py | 12 ++---- .../ppdiffusers/transformers/t5/modeling.py | 22 ++++++----- 5 files changed, 39 insertions(+), 40 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index a0402a7c3..0c911bb4c 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -68,7 +68,7 @@ def parse_args(): pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", - paddle_dtype=paddle.float16, + paddle_dtype=paddle.float32, # from_hf_hub=True, # from_diffusers=True, ) @@ -92,7 +92,6 @@ def parse_args(): ) - pipe.text_encoder_3 = paddle.incubate.jit.inference( pipe.text_encoder_3, save_model_dir="./tmp/text_encoder_3_T5", @@ -103,7 +102,6 @@ def parse_args(): ) - # for vae model pipe.vae.decode = paddle.incubate.jit.inference( pipe.vae.decode, @@ -133,7 +131,7 @@ def parse_args(): generator=generator, ).images[0] - repeat_times = 6 + repeat_times = 10 sumtime = 0.0 for i in range(repeat_times): paddle.device.synchronize() diff --git a/ppdiffusers/ppdiffusers/patches/paddle_patch.py b/ppdiffusers/ppdiffusers/patches/paddle_patch.py index 6845ad632..b55cbf138 100644 --- a/ppdiffusers/ppdiffusers/patches/paddle_patch.py +++ b/ppdiffusers/ppdiffusers/patches/paddle_patch.py @@ -429,7 +429,7 @@ def scaled_dot_product_attention_( # (2) FLAG_USE_CUTLASS_V2 in yes, y, true, t, 1, use cutlass v2 use_cutlass_v2 = attn_mask is not None or str2bool(os.getenv("FLAG_USE_CUTLASS_V2", "no")) if not use_cutlass_v2: - with requires_grad_and_without_random(query, key, value): + with requires_grad_and_without_random(query, key, value, stop_gradient=False): output = memory_efficient_attention( query, key, diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 4c4c146e6..52287e905 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -19,6 +19,7 @@ import paddle from ppdiffusers.transformers import ( # T5TokenizerFast, + CLIPTextModelOutput, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, @@ -214,7 +215,7 @@ def _get_t5_prompt_embeds( text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pd").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -222,8
+223,8 @@ def _get_t5_prompt_embeds( ) outputs = self.text_encoder_3(text_input_ids) - # in order to d2s if isinstance(outputs, paddle.Tensor): + # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference prompt_embeds = outputs else: prompt_embeds = outputs[0] @@ -266,30 +267,31 @@ def _get_clip_prompt_embeds( text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pd").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - # prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) - # pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) - # if clip_skip is None: - # prompt_embeds = prompt_embeds.hidden_states[-2] - # else: - # prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + if isinstance(prompt_embeds, CLIPTextModelOutput): + pooled_prompt_embeds = prompt_embeds[0] - - prompt_embeds = text_encoder(text_input_ids) - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds[1:][-2] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + elif isinstance(prompt_embeds, list): + # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference + pooled_prompt_embeds = prompt_embeds[-1] + if clip_skip is None: + prompt_embeds = prompt_embeds[:-2][-2] + else: + prompt_embeds = prompt_embeds[:-2][-(clip_skip + 2)] else: - prompt_embeds = prompt_embeds[1:][-(clip_skip + 2)] + raise ValueError(f"Unexpected CLIP text encoder output: {type(prompt_embeds)}") - pooled_prompt_embeds = pooled_prompt_embeds.astype(dtype=text_encoder.dtype) prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) @@ -818,6 +820,7 @@ def __call__( joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, ) + if isinstance(noise_pred_out, paddle.Tensor): noise_pred = noise_pred_out else: diff --git a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py index b73078c82..11b5e50c8 100644 --- a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py @@ -1256,9 +1256,9 @@ def forward( input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - # output_attentions: Optional[bool] = None, - # output_hidden_states: Optional[bool] = None, - # return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPTextModelOutput]: r""" Returns: @@ -1276,11 +1276,7 @@ def forward( >>> outputs = model(**inputs) >>> text_embeds = outputs.text_embeds ```""" - - output_attentions = None - output_hidden_states = True - return_dict = False - + return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py
b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index db38b7a3b..cf4df7d97 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -24,6 +24,7 @@ from paddle import nn from paddle.amp.auto_cast import amp_state from paddle.distributed import fleet +from paddle.framework import in_dynamic_or_pir_mode from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from paddlenlp.transformers.activations import ACT2FN from paddlenlp.transformers.conversion_utils import ( @@ -40,7 +41,6 @@ Seq2SeqSequenceClassifierOutput, ) from paddlenlp.transformers.model_utils import register_base_model -from paddle.framework import in_dynamic_or_pir_mode from ...utils import logging from ..model_utils import ALL_LAYERNORM_LAYERS, PretrainedModel @@ -1556,8 +1556,11 @@ def __init__(self, config: T5Config): # Initialize weights and apply final processing self.post_init() - # in order to d2s - del self.encoder + # NOTE:(changwenbin,zhoukangkang) + # When you use 'paddle.incubate.jit.inference' to reconstruct the model, + # if you have set 'cache_static_model=True', + # you can use 'del self.encoder' to reduce the global memory usage. + # del self.encoder def get_input_embeddings(self): return self.shared @@ -1578,6 +1581,9 @@ def forward( input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, inputs_embeds: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[paddle.Tensor], BaseModelOutput]: r""" Returns: @@ -1595,11 +1601,7 @@ def forward( >>> outputs = model(input_ids=input_ids) >>> last_hidden_states = outputs.last_hidden_state ```""" - - output_attentions = None - output_hidden_states = None - return_dict = None - + return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_output = self.encoder( @@ -1610,11 +1612,11 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - + if in_dynamic_or_pir_mode(): return encoder_output else: - # there is a bug in dy2s,we fix it here. + # NOTE:(changwenbin,zhoukangkang)there is a bug in dy2s,we fix it here. 
return encoder_output.last_hidden_state From 5a14a0f86b83527e1e59c500a65ed4e65615624a Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 27 Aug 2024 13:27:14 +0000 Subject: [PATCH 28/65] update clip --- ppdiffusers/ppdiffusers/transformers/clip/modeling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py index 11b5e50c8..1029b1da1 100644 --- a/ppdiffusers/ppdiffusers/transformers/clip/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/clip/modeling.py @@ -1276,7 +1276,6 @@ def forward( >>> outputs = model(**inputs) >>> text_embeds = outputs.text_embeds ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( From cd2ef0165addfc3ea7fc74dc3a8b37a32796ed51 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 27 Aug 2024 13:28:41 +0000 Subject: [PATCH 29/65] update T5 --- ppdiffusers/ppdiffusers/transformers/t5/modeling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index cf4df7d97..c005adb06 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -1592,7 +1592,6 @@ def forward( ```python >>> from ppdiffusers.transformers import AutoTokenizer, T5EncoderModel - >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") >>> model = T5EncoderModel.from_pretrained("t5-small") >>> input_ids = tokenizer( From 624168cb0ba994ad9a2a4c11d691b2ca4f5c7c93 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 27 Aug 2024 13:30:17 +0000 Subject: [PATCH 30/65] update T5 --- ppdiffusers/ppdiffusers/transformers/t5/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index c005adb06..ede6e5784 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -1592,6 +1592,7 @@ def forward( ```python >>> from ppdiffusers.transformers import AutoTokenizer, T5EncoderModel + >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") >>> model = T5EncoderModel.from_pretrained("t5-small") >>> input_ids = tokenizer( @@ -1600,7 +1601,6 @@ def forward( >>> outputs = model(input_ids=input_ids) >>> last_hidden_states = outputs.last_hidden_state ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_output = self.encoder(
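A side note on the `torch.equal` port in PATCH 27 above: Paddle's elementwise `paddle.equal` returns a boolean *tensor*, which cannot stand in for a Python condition, so the drop-in analogue of `torch.equal` is the reducing `paddle.equal_all`. A minimal sketch of the difference (the token IDs are illustrative only, not from the repo):

```python
import paddle

text_input_ids = paddle.to_tensor([[49406, 320, 2368, 49407]])
untruncated_ids = paddle.to_tensor([[49406, 320, 2368, 49407]])

# paddle.equal compares elementwise and returns a bool tensor of shape [1, 4];
# `not paddle.equal(a, b)` would try to squeeze a multi-element tensor into one bool.
elementwise = paddle.equal(text_input_ids, untruncated_ids)

# paddle.equal_all reduces to a single boolean, matching torch.equal semantics,
# which is what the truncation-warning check needs.
if not paddle.equal_all(text_input_ids, untruncated_ids):
    print("prompt was truncated")
```
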
From b009b9fb18796537e3057a7ce59d36b9152ceea5 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 27 Aug 2024 13:35:55 +0000 Subject: [PATCH 31/65] update scheduling_flow_match_euler_discrete --- .../schedulers/scheduling_flow_match_euler_discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py index 218d6c133..f26f997f7 100644 --- a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -244,7 +244,7 @@ def step( sample = sample.cast(paddle.float32) sigma = self.sigmas[self.step_index] - + # NOTE:(changwenbin & zhoukangkang) when s_churn == 0.0 there is no need to compute gamma, which reduces CUDA synchronization if s_churn == 0.0: gamma = 0.0 else: From 8caa10ae6cce6a46095cd79a057a4d768cd9888f Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 28 Aug 2024 03:45:49 +0000 Subject: [PATCH 32/65] update normalization --- ppdiffusers/ppdiffusers/models/normalization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 10c0318b0..832c42f0c 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -62,8 +62,9 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) else: - # print("Using None") this + self.emb = None + self.silu = nn.Silu() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim) norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) From 377629aeb3b249e3048982c7ba767b1fc5d895ec Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 28 Aug 2024 03:48:37 +0000 Subject: [PATCH 33/65] update normalization --- ppdiffusers/ppdiffusers/models/normalization.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 832c42f0c..6b8131889 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -62,7 +62,6 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) else: - self.emb = None @@ -82,7 +81,6 @@ def forward( if self.emb is not None: emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) if os.getenv("INFERENCE_OPTIMIZE_TRITON"): From 15fda4e24382b9ec47f936d4914f382d804922d8 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 29 Aug 2024 09:43:15 +0000 Subject: [PATCH 34/65] update SD3 --- ..._to_image_generation-stable_diffusion_3.py | 2 +- .../ppdiffusers/models/attention_processor.py | 15 ++---------- .../pipeline_stable_diffusion_3.py | 24 +++++++++---------- 3 files changed, 15 insertions(+), 26 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 0c911bb4c..71ecb8e0d 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -68,7 +68,7 @@ def parse_args(): pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", - paddle_dtype=paddle.float32, + paddle_dtype=paddle.float16, # from_hf_hub=True, # from_diffusers=True, ) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index f2c742cfc..925c17085 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -924,6 +924,7 @@ def __call__( **kwargs, ) -> paddle.Tensor: residual = hidden_states + input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape @@ -947,10 +948,6 @@ def __call__( encoder_hidden_states_key_proj =
attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - # print("hidden_states_q", encoder_hidden_states_query_proj) - # print("hidden_states_K", encoder_hidden_states_key_proj) - # print("hidden_states_V", encoder_hidden_states_value_proj) - # attention query = paddle.concat([query, encoder_hidden_states_query_proj], axis=1) key = paddle.concat([key, encoder_hidden_states_key_proj], axis=1) @@ -968,12 +965,6 @@ def __call__( hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) hidden_states = hidden_states.astype(query.dtype) - # print("hidden_states",hidden_states) - # print("encoder_hidden_states",encoder_hidden_states) - # hidden_states.fill_(0.11189012) - # print("hidden_states", hidden_states) - # print("encoder_hidden_states",norm_encoder_hidden_states) - # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], @@ -982,10 +973,8 @@ def __call__( # linear proj hidden_states = attn.to_out[0](hidden_states) - # print(type(attn.to_out[0])) - # print("hidden_states", hidden_states) - # dropout + # dropout hidden_states = attn.to_out[1](hidden_states) if not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 52287e905..193b57b21 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -223,7 +223,7 @@ def _get_t5_prompt_embeds( ) outputs = self.text_encoder_3(text_input_ids) - if isinstance(outputs, paddle.Tensor): + if paddle.incubate.jit.is_inference_mode(self.text_encoder_3): # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference prompt_embeds = outputs else: prompt_embeds = outputs[0] @@ -275,14 +275,7 @@ def _get_clip_prompt_embeds( ) prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) - if isinstance(prompt_embeds, CLIPTextModelOutput): - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - elif isinstance(prompt_embeds, list): + if paddle.incubate.jit.is_inference_mode(text_encoder): # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference pooled_prompt_embeds = prompt_embeds[-1] if clip_skip is None: @@ -290,7 +283,12 @@ def _get_clip_prompt_embeds( else: prompt_embeds = prompt_embeds[:-2][-(clip_skip + 2)] else: - raise ValueError(f"Unexpected CLIP text encoder output: {type(prompt_embeds)}") + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] pooled_prompt_embeds = pooled_prompt_embeds.astype(dtype=text_encoder.dtype) prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) @@ -859,9 +857,11 @@ def __call__( latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor # in order to d2s - latents = latents.cast("float32") + if paddle.incubate.jit.is_inference_mode(self.vae.decode): + latents = latents.cast("float32") image_out = self.vae.decode(latents, return_dict=False) - if isinstance(image_out, paddle.Tensor): + if paddle.incubate.jit.is_inference_mode(self.vae.decode): + # NOTE:(changwenbin,zhoukangkang)
this is for paddle.incubate.jit.inference image = image_out else: image = image_out[0] From 0e90eaf63dfdd857a56a1dfaa710c628a318c41c Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 2 Sep 2024 12:38:01 +0000 Subject: [PATCH 35/65] update cutlass gemm&fast_gelu --- .../inference/text_to_image_generation-stable_diffusion_3.py | 4 ++-- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 71ecb8e0d..ae634da9d 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -148,9 +148,9 @@ def parse_args(): duringtime = endtime - starttime duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 sumtime += duringtime - print("The this end to end time : ", duringtime, "ms") + print("SD3 end to end time : ", duringtime, "ms") - print("The ave end to end time : ", sumtime / repeat_times, "ms") + print("SD3 ave end to end time : ", sumtime / repeat_times, "ms") cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3) print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index f1920311a..bee4b151a 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -118,12 +118,13 @@ def __init__( ) # self.simplified_sd3 = paddle.incubate.jit.inference( # self.simplified_sd3, + # save_model_dir="./tmp/sd3", # enable_new_ir=True, # cache_static_model=False, # exp_enable_use_cutlass=True, # delete_pass_lists=["add_norm_fuse_pass"], # ) - if self.inference_optimize_origin: + elif self.inference_optimize_origin: self.sd3_origin_transformer = paddle.incubate.jit.inference( self.sd3_origin_transformer, enable_new_ir=True, From c5bb81f534dda23127793ffa972d95a653244f6e Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 4 Sep 2024 08:53:49 +0000 Subject: [PATCH 36/65] update per-mmdit --- ppdiffusers/ppdiffusers/models/embeddings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index de67bd792..d97c4e293 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -53,11 +53,11 @@ def get_timestep_embedding( emb = scale * emb # concat sine and cosine embeddings - emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1) + emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + # if flip_sin_to_cos: + # emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) # zero pad if embedding_dim % 2 == 1: From 499752abcb4a3a109addd01c27a92ba699810482 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 4 Sep 2024 12:16:46 +0000 Subject: [PATCH 37/65] update triton op split_concat ---
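PATCH 37 below stops hardcoding `BLOCK_SIZE=2048` at the kernel launch and passes the same `BLOCK_SIZE` variable that the launcher's grid and masking math were computed from. For context, a generic sketch of why the two must agree — this is plain Triton, not the actual paddlemix `split_concat` kernel or its custom `[(op_name, grid)]` launch mechanism:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(dst_ptr, src_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # BLOCK_SIZE is a compile-time constant; the offsets below assume it is
    # exactly the value the launcher used to size the grid.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail block
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, x, mask=mask)

def launch_copy(dst, src, n_elements, BLOCK_SIZE=2048):
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Passing a literal here that differs from the BLOCK_SIZE used to compute
    # `grid` would make blocks read and write the wrong elements -- the bug
    # class the one-line fix below removes.
    copy_kernel[grid](dst, src, n_elements, BLOCK_SIZE=BLOCK_SIZE)
```
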
paddlemix/triton_ops/triton_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index f8a918df7..41fc718bf 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -1707,7 +1707,7 @@ def split_concat(x, y): grid = ("3", "batch", "seq_qkv + seq_eqkv") split_concat_kernel[(op_name, grid)]( - out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=2048 + out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=BLOCK_SIZE ) if in_dynamic_or_pir_mode(): From 1084f4aeb1932042b53b014f92bf91ca63501e36 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 5 Sep 2024 10:34:44 +0000 Subject: [PATCH 38/65] update embeddings --- ppdiffusers/ppdiffusers/models/embeddings.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index d97c4e293..0369e3bb0 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -52,12 +52,11 @@ def get_timestep_embedding( # scale embeddings emb = scale * emb - # concat sine and cosine embeddings - emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) - # flip sine and cosine embeddings - # if flip_sin_to_cos: - # emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + if flip_sin_to_cos: + emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) + else: + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1) # zero pad if embedding_dim % 2 == 1: From fa8455939c499573a6b5712f0327c96c09ca7435 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Fri, 6 Sep 2024 06:45:31 +0000 Subject: [PATCH 39/65] recovery --- .../ppdiffusers/models/attention_processor.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 925c17085..c93c55ae6 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -906,7 +906,6 @@ def __call__( return hidden_states - class JointAttnProcessor2_5: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -932,9 +931,7 @@ def __call__( context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose( - [0, 2, 1] - ) + encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose([0, 2, 1]) batch_size = encoder_hidden_states.shape[0] @@ -973,7 +970,6 @@ def __call__( # linear proj hidden_states = attn.to_out[0](hidden_states) - # dropout hidden_states = attn.to_out[1](hidden_states) if not attn.context_pre_only: @@ -982,9 +978,7 @@ def __call__( if input_ndim == 4: hidden_states = hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape( - [batch_size, channel, height, width] - ) + encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) return hidden_states, encoder_hidden_states @@ -1015,9 +1009,7 @@ def __call__( 
context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose( - [0, 2, 1] - ) + encoder_hidden_states = encoder_hidden_states.reshape([batch_size, channel, height * width]).transpose([0, 2, 1]) batch_size = encoder_hidden_states.shape[0] @@ -1068,13 +1060,10 @@ def __call__( if input_ndim == 4: hidden_states = hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape( - [batch_size, channel, height, width] - ) + encoder_hidden_states = encoder_hidden_states.transpose([0, 1, 3, 2]).reshape([batch_size, channel, height, width]) return hidden_states, encoder_hidden_states - class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. From 27c62f9da4107a11bd9d0c1019b6872853671a49 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Fri, 6 Sep 2024 07:18:22 +0000 Subject: [PATCH 40/65] recovery --- ppdiffusers/ppdiffusers/models/attention.py | 45 +++++++-------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index 88a1abaf0..8b5a9d027 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Dict, Optional import paddle -import paddle.nn.functional as F from paddle import nn +import paddle.nn.functional as F from ..utils import USE_PEFT_BACKEND from ..utils.paddle_utils import maybe_allow_in_graph @@ -93,7 +92,6 @@ def forward(self, x: paddle.Tensor, objs: paddle.Tensor) -> paddle.Tensor: return x - @maybe_allow_in_graph class JointTransformerBlock(nn.Layer): r""" @@ -114,6 +112,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" self.norm1 = AdaLayerNormZero(dim) + if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" @@ -162,7 +161,9 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_size = chunk_size self._chunk_dim = dim - def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor): + def forward( + self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor + ): norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: @@ -174,23 +175,15 @@ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Te # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, ) # Process attention outputs for the `hidden_states`. 
- if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - - hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06 - ) - else: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) @@ -204,20 +197,11 @@ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Te if self.context_pre_only: encoder_hidden_states = None else: - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - - encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06 - ) - else: - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = ( - norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - ) + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory context_ff_output = _chunked_feed_forward( @@ -229,7 +213,6 @@ def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Te return encoder_hidden_states, hidden_states - @maybe_allow_in_graph class BasicTransformerBlock(nn.Layer): r""" From 9515323b3ad8aea0559f3021c46926470a7df155 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 10 Sep 2024 07:19:28 +0000 Subject: [PATCH 41/65] update normalization --- ppdiffusers/ppdiffusers/models/normalization.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 6b8131889..33e715675 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -82,16 +82,7 @@ def forward( emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - - if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - - x = paddlemix.triton_ops.adaptive_layer_norm( - x, scale_msa, shift_msa, self.norm.weight, self.norm.bias, epsilon=1e-06 - ) - else: - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp From d61e4cb05b07030055ce6ed5a41f9828aa304690 Mon Sep 17 
00:00:00 2001 From: changwenbin Date: Tue, 10 Sep 2024 07:51:25 +0000 Subject: [PATCH 42/65] update dtype --- .../text_to_image_generation-stable_diffusion_3.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index ae634da9d..7423ead69 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -53,6 +53,8 @@ def parse_args(): parser.add_argument("--height", type=int, default=512, help="Height of the generated image.") parser.add_argument("--width", type=int, default=512, help="Width of the generated image.") parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.") + parser.add_argument("--dtype", type=str, default="float32", help="Inference data type: float32 or float16.") + return parser.parse_args() @@ -64,13 +66,14 @@ def parse_args(): os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" if args.inference_optimize_origin: os.environ["INFERENCE_OPTIMIZE_ORIGIN"] = "True" - +if args.dtype == "float32": + inference_dtype = paddle.float32 +elif args.dtype == "float16": + inference_dtype = paddle.float16 pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", - paddle_dtype=paddle.float16, - # from_hf_hub=True, - # from_diffusers=True, + paddle_dtype=inference_dtype, ) From d961a4a6175eb954d11566718e33e9cc463555ab Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 10 Sep 2024 13:40:16 +0000 Subject: [PATCH 43/65] add SD3 doc --- ppdiffusers/deploy/sd3/README.md | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 ppdiffusers/deploy/sd3/README.md diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md new file mode 100644 index 000000000..be99e249c --- /dev/null +++ b/ppdiffusers/deploy/sd3/README.md @@ -0,0 +1,39 @@ +# Stable Diffusion 3 High-Performance Inference + +- Paddle Inference provides a high-performance inference implementation for the Stable Diffusion 3 model, improving inference performance by 70%+ +Environment setup: +```shell +# Install triton and adapt it to paddle +python -m pip install triton +python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git +python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()" + +# Install the develop version of paddle +python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ + +#Set the Tensor-RT lib path +export LD_LIBRARY_PATH=/your_TensorRT_dir//lib:$LD_LIBRARY_PATH + +#Set the cutlass package path +export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH +``` + +High-performance inference commands: +```shell +#step1: generate the FP32 TRT model +python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \ +--num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ +--benchmark 1 + +#step2: run FP16 inference +python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \ +--num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ +--benchmark 1 +``` + +- Performance measured on an NVIDIA A100-SXM4-40GB is as follows: + +| Paddle Inference| OneDiff | PyTorch | Paddle dygraph | +| --------------- | ------------ | ------------ | ------------ | +| 1.2 s | 1.58 s | 1.78 s | 4.202 s | From 48c66a6beb10c058065c04e5884d95f58abb27e9 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 18 Sep 2024 12:05:33 +0000 Subject: [PATCH 44/65] update SD3 doc --- ppdiffusers/deploy/sd3/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md index be99e249c..81a93cac6 100644 --- a/ppdiffusers/deploy/sd3/README.md +++ b/ppdiffusers/deploy/sd3/README.md @@ -21,7 +21,7 @@ export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_e High-performance inference commands: ```shell -#step1: generate the FP32 TRT model +#step1: generate the FP32 paddle model and build the FP16 TensorRT engine from it at the same time. python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \ --num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ --benchmark 1 From 24c3c9e585261fd4d22e4adbf8df169eb0aaabe2 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 02:42:00 +0000 Subject: [PATCH 45/65] add 'del transformer_blocks' --- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index bee4b151a..5731efe87 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -110,6 +110,8 @@ def __init__( ] ) if self.inference_optimize: + # we do not need self.transformer_blocks, del it to save memory. + del self.transformer_blocks self.simplified_sd3 = SimplifiedSD3( num_layers, dim=self.inner_dim, From 422f33bf0cb06c141c62cd87174da831eb085827 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 06:16:55 +0000 Subject: [PATCH 46/65] update SD3 --- ppdiffusers/deploy/sd3/README.md | 14 +++++++------- ...xt_to_image_generation-stable_diffusion_3.py | 9 +++++++++ .../ppdiffusers/models/transformer_sd3.py | 17 +---------------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md index 81a93cac6..791130bf0 100644 --- a/ppdiffusers/deploy/sd3/README.md +++ b/ppdiffusers/deploy/sd3/README.md @@ -11,22 +11,22 @@ python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()" # Install the develop version of paddle python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ -#Set the Tensor-RT lib path +# Set the Tensor-RT lib path export LD_LIBRARY_PATH=/your_TensorRT_dir//lib:$LD_LIBRARY_PATH -#Set the cutlass package path +# Set the cutlass package path export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH ``` High-performance inference commands: ```shell -#step1: generate the FP32 paddle model and build the FP16 TensorRT engine from it at the same time. +# step1: generate the FP32 paddle model and build the FP16 TensorRT engine from it at the same time. python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \ --num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ --benchmark 1 -#step2: run FP16 inference +# step2: run FP16 inference python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \ --num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ --benchmark 1 ``` - Performance measured on an NVIDIA A100-SXM4-40GB is as follows: -| Paddle Inference| OneDiff | PyTorch | Paddle dygraph | -| --------------- | ------------ | ------------ | ------------ | -| 1.2 s | 1.58 s | 1.78 s | 4.202 s | +| Paddle Inference| PyTorch | Paddle dygraph | +| --------------- | ------------ | ------------ | +| 1.2 s | 1.78 s | 4.202 s | diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 7423ead69..b1d638627 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -104,6 +104,15 @@ def parse_args(): trt_use_static=True, ) +pipe.transformer = paddle.incubate.jit.inference( + pipe.transformer, + save_model_dir="./tmp/sd3", + enable_new_ir=True, + cache_static_model=True, + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], +) + # for vae model pipe.vae.decode = paddle.incubate.jit.inference( pipe.vae.decode, diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 5731efe87..a55384c19 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -118,14 +118,6 @@ def __init__( num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.inner_dim, ) - # self.simplified_sd3 = paddle.incubate.jit.inference( - # self.simplified_sd3, - # save_model_dir="./tmp/sd3", - # enable_new_ir=True, - # cache_static_model=False, - # exp_enable_use_cutlass=True, - # delete_pass_lists=["add_norm_fuse_pass"], - # ) elif self.inference_optimize_origin: self.sd3_origin_transformer = paddle.incubate.jit.inference( self.sd3_origin_transformer, enable_new_ir=True, @@ -295,13 +287,6 @@ def custom_forward(*inputs): ) return encoder_hidden_states, hidden_states - @paddle.incubate.jit.inference( - enable_new_ir=True, - cache_static_model=True, - save_model_dir="./tmp/sd3", - exp_enable_use_cutlass=True, - delete_pass_lists=["add_norm_fuse_pass"], - ) def forward( self, hidden_states: paddle.Tensor, @@ -385,7 +370,7 @@ def forward( hidden_states = hidden_states.reshape( shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) ) - # hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = paddle.transpose(hidden_states, [0, 5, 1, 3, 2, 4]) output = hidden_states.reshape( shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) ) From c43d84fc9b633def3e3f7b83bad9d6d44a9bc585 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 06:35:30 +0000 Subject: [PATCH 47/65] update SD3 --- .../ppdiffusers/models/transformer_sd3.py | 66 ++++++++----------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index a55384c19..03145a666 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -425,47 +425,39 @@ def custom_modify_weight(cls, state_dict): for to_, from_ in map_sd3: if "transformer_blocks." + from_ in state_dict: - state_dict["simplified_sd3." + to_] = paddle.assign(state_dict["transformer_blocks." + from_]) + state_dict["simplified_sd3." + to_] = state_dict["transformer_blocks."
+ from_] else: print(f"Warning!!: '{from_}' not found in state_dict") - state_dict[f"simplified_sd3.qkv.{i}.weight"] = paddle.assign( - paddle.concat( - [ - state_dict[f"simplified_sd3.q.{i}.weight"], - state_dict[f"simplified_sd3.k.{i}.weight"], - state_dict[f"simplified_sd3.v.{i}.weight"], - ], - axis=1, - ) + state_dict[f"simplified_sd3.qkv.{i}.weight"] = paddle.concat( + [ + state_dict[f"simplified_sd3.q.{i}.weight"], + state_dict[f"simplified_sd3.k.{i}.weight"], + state_dict[f"simplified_sd3.v.{i}.weight"], + ], + axis=1, ) - state_dict[f"simplified_sd3.qkv.{i}.bias"] = paddle.assign( - paddle.concat( - [ - state_dict[f"simplified_sd3.q.{i}.bias"], - state_dict[f"simplified_sd3.k.{i}.bias"], - state_dict[f"simplified_sd3.v.{i}.bias"], - ], - axis=0, - ) + state_dict[f"simplified_sd3.qkv.{i}.bias"] = paddle.concat( + [ + state_dict[f"simplified_sd3.q.{i}.bias"], + state_dict[f"simplified_sd3.k.{i}.bias"], + state_dict[f"simplified_sd3.v.{i}.bias"], + ], + axis=0, ) - state_dict[f"simplified_sd3.eqkv.{i}.weight"] = paddle.assign( - paddle.concat( - [ - state_dict[f"simplified_sd3.eq.{i}.weight"], - state_dict[f"simplified_sd3.ek.{i}.weight"], - state_dict[f"simplified_sd3.ev.{i}.weight"], - ], - axis=1, - ) + state_dict[f"simplified_sd3.eqkv.{i}.weight"] = paddle.concat( + [ + state_dict[f"simplified_sd3.eq.{i}.weight"], + state_dict[f"simplified_sd3.ek.{i}.weight"], + state_dict[f"simplified_sd3.ev.{i}.weight"], + ], + axis=1, ) - state_dict[f"simplified_sd3.eqkv.{i}.bias"] = paddle.assign( - paddle.concat( - [ - state_dict[f"simplified_sd3.eq.{i}.bias"], - state_dict[f"simplified_sd3.ek.{i}.bias"], - state_dict[f"simplified_sd3.ev.{i}.bias"], - ], - axis=0, - ) + state_dict[f"simplified_sd3.eqkv.{i}.bias"] = paddle.concat( + [ + state_dict[f"simplified_sd3.eq.{i}.bias"], + state_dict[f"simplified_sd3.ek.{i}.bias"], + state_dict[f"simplified_sd3.ev.{i}.bias"], + ], + axis=0, ) From 9d036249d119b90872b9d9f12724b6e9a378d6eb Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 08:44:32 +0000 Subject: [PATCH 48/65] update Notes --- ppdiffusers/ppdiffusers/models/autoencoder_kl.py | 3 +++ ppdiffusers/ppdiffusers/models/transformer_sd3.py | 1 + ppdiffusers/ppdiffusers/transformers/t5/modeling.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py index 9f975ea22..8ad84281c 100644 --- a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py +++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py @@ -89,6 +89,7 @@ def __init__( use_quant_conv: bool = True, use_post_quant_conv: bool = True, ): + # set USE_PPXFORMERS=False to avoid using ppxformers os.environ["USE_PPXFORMERS"] = "False" super().__init__() # if down_block_out_channels not given, we will use block_out_channels @@ -119,6 +120,8 @@ def __init__( act_fn=act_fn, ) del os.environ["USE_PPXFORMERS"] + # del set USE_PPXFORMERS=False to Restore Defaults + self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) if use_post_quant_conv else None diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 03145a666..9e40537e9 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -387,6 +387,7 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): + # SD3 num_layers is 
24 for i in range(24): base_map_sd3 = [ (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index ede6e5784..4dab1d130 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -1557,7 +1557,7 @@ def __init__(self, config: T5Config): self.post_init() # NOTE:(changwenbin,zhoukangkang) - # When you use 'paddle.incubate.jit.inference' to reconstruct the model, + # When you use 'paddle.incubate.jit.inference' to speed up your model, # if you have set 'cache_static_model=True', # you can use 'del self.encoder' to reduce the global memory usage. # del self.encoder From ded06bf803f019ea3025237e08857d812f75e6ad Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 09:41:23 +0000 Subject: [PATCH 49/65] add Notes --- ppdiffusers/ppdiffusers/models/autoencoder_kl.py | 5 +++-- ppdiffusers/ppdiffusers/models/transformer_sd3.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py index 8ad84281c..e273094a1 100644 --- a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py +++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py @@ -89,7 +89,8 @@ def __init__( use_quant_conv: bool = True, use_post_quant_conv: bool = True, ): - # set USE_PPXFORMERS=False to avoid using ppxformers + # NOTE:(changwenbin,zhoukangkang) SD3 vae use memory_efficient_attention op which is not well supported by Paddle-TensorRT + # so set USE_PPXFORMERS=False to avoid using memory_efficient_attention op. os.environ["USE_PPXFORMERS"] = "False" super().__init__() # if down_block_out_channels not given, we will use block_out_channels @@ -120,7 +121,7 @@ def __init__( act_fn=act_fn, ) del os.environ["USE_PPXFORMERS"] - # del set USE_PPXFORMERS=False to Restore Defaults + # NOTE:(changwenbin,zhoukangkang) del set USE_PPXFORMERS=False to Restore Defaults self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) if use_post_quant_conv else None diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 9e40537e9..b5c7a4286 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -387,8 +387,9 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - # SD3 num_layers is 24 - for i in range(24): + # NOTE:(changwenbin,zhoukangkang) SD3 num_layers is 24 + sd3_num_layers = 24 + for i in range(sd3_num_layers): base_map_sd3 = [ (f"linear1.{i}.weight", f"{i}.norm1.linear.weight"), (f"linear1.{i}.bias", f"{i}.norm1.linear.bias"), @@ -413,7 +414,7 @@ def custom_modify_weight(cls, state_dict): (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), ] - if i < 23: + if i < sd3_num_layers - 1: extra_map_sd3 = [ (f"to_add_out_linear.{i}.weight", f"{i}.attn.to_add_out.weight"), (f"to_add_out_linear.{i}.bias", f"{i}.attn.to_add_out.bias"), From d845da2e4d0b79c426f2e41392b08a6ca0f445b5 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 10:16:03 +0000 Subject: [PATCH 50/65] update demo --- .../text_to_image_generation-stable_diffusion_3.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git 
a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index b1d638627..f4eae3c27 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -38,12 +38,6 @@ def parse_args(): default=False, help="If inference_optimize is set to True, all optimizations except Triton are enabled.", ) - parser.add_argument( - "--inference_optimize_triton", - type=(lambda x: str(x).lower() in ["true", "1", "yes"]), - default=False, - help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.", - ) parser.add_argument( "--inference_optimize_origin", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), @@ -62,7 +56,6 @@ def parse_args(): if args.inference_optimize: os.environ["INFERENCE_OPTIMIZE"] = "True" -if args.inference_optimize_triton: os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" if args.inference_optimize_origin: os.environ["INFERENCE_OPTIMIZE_ORIGIN"] = "True" @@ -108,7 +101,7 @@ def parse_args(): pipe.transformer, save_model_dir="./tmp/sd3", enable_new_ir=True, - cache_static_model=True, + cache_static_model=False, exp_enable_use_cutlass=True, delete_pass_lists=["add_norm_fuse_pass"], ) From db6aad108d391b54a4923ea2ed6524af2d59dc92 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 19 Sep 2024 10:18:34 +0000 Subject: [PATCH 51/65] update doc --- ppdiffusers/deploy/sd3/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md index 791130bf0..5d84c3507 100644 --- a/ppdiffusers/deploy/sd3/README.md +++ b/ppdiffusers/deploy/sd3/README.md @@ -23,12 +23,12 @@ export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_e ```shell # step1: generate the FP32 paddle model and build the FP16 TensorRT engine from it at the same time. python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \ ---num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ +--num-inference-steps 50 --inference_optimize 1 \ --benchmark 1 # step2: run FP16 inference python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \ ---num-inference-steps 50 --inference_optimize 1 --inference_optimize_triton 1 \ +--num-inference-steps 50 --inference_optimize 1 \ --benchmark 1 ``` From 33f37aef11acdb5d3238052632dd83fd6872b307 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 19 Sep 2024 21:02:11 +0800 Subject: [PATCH 52/65] first commit --- ..._to_image_generation-stable_diffusion_3.py | 39 ------------------- .../ppdiffusers/models/autoencoder_kl.py | 6 --- 2 files changed, 45 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index f4eae3c27..562be5743 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -69,34 +69,6 @@ def parse_args(): paddle_dtype=inference_dtype, ) -pipe.text_encoder = paddle.incubate.jit.inference( - pipe.text_encoder, - save_model_dir="./tmp/text_encoder", - cache_static_model=True, - with_trt=True, - trt_precision_mode="float16", - trt_use_static=True, -) - -pipe.text_encoder_2 = paddle.incubate.jit.inference( -
pipe.text_encoder_2, - save_model_dir="./tmp/text_encoder_2", - cache_static_model=True, - with_trt=True, - trt_precision_mode="float16", - trt_use_static=True, -) - - -pipe.text_encoder_3 = paddle.incubate.jit.inference( - pipe.text_encoder_3, - save_model_dir="./tmp/text_encoder_3_T5", - cache_static_model=True, - with_trt=True, - trt_precision_mode="float16", - trt_use_static=True, -) - pipe.transformer = paddle.incubate.jit.inference( pipe.transformer, save_model_dir="./tmp/sd3", @@ -106,17 +78,6 @@ def parse_args(): delete_pass_lists=["add_norm_fuse_pass"], ) - -# for vae model -pipe.vae.decode = paddle.incubate.jit.inference( - pipe.vae.decode, - save_model_dir="./tmp/vae_static_models", - cache_static_model=True, - with_trt=True, - trt_precision_mode="float16", - trt_use_static=True, -) - generator = paddle.Generator().manual_seed(42) prompt = "A cat holding a sign that says hello world" diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py index e273094a1..d8fda325a 100644 --- a/ppdiffusers/ppdiffusers/models/autoencoder_kl.py +++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Dict, Optional, Tuple, Union import paddle @@ -89,9 +88,6 @@ def __init__( use_quant_conv: bool = True, use_post_quant_conv: bool = True, ): - # NOTE:(changwenbin,zhoukangkang) SD3 vae use memory_efficient_attention op which is not well supported by Paddle-TensorRT - # so set USE_PPXFORMERS=False to avoid using memory_efficient_attention op. - os.environ["USE_PPXFORMERS"] = "False" super().__init__() # if down_block_out_channels not given, we will use block_out_channels _down_block_out_channels = block_out_channels if down_block_out_channels is None else down_block_out_channels @@ -120,8 +116,6 @@ def __init__( norm_num_groups=norm_num_groups, act_fn=act_fn, ) - del os.environ["USE_PPXFORMERS"] - # NOTE:(changwenbin,zhoukangkang) del set USE_PPXFORMERS=False to Restore Defaults self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) if use_post_quant_conv else None From 74e04196e3dd10e1bb4b426415ddc3282f95978a Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 19 Sep 2024 21:04:46 +0800 Subject: [PATCH 53/65] first commit --- ppdiffusers/ppdiffusers/transformers/t5/modeling.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py index 4dab1d130..c3fd10cd3 100644 --- a/ppdiffusers/ppdiffusers/transformers/t5/modeling.py +++ b/ppdiffusers/ppdiffusers/transformers/t5/modeling.py @@ -24,7 +24,6 @@ from paddle import nn from paddle.amp.auto_cast import amp_state from paddle.distributed import fleet -from paddle.framework import in_dynamic_or_pir_mode from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from paddlenlp.transformers.activations import ACT2FN from paddlenlp.transformers.conversion_utils import ( @@ -1556,12 +1555,6 @@ def __init__(self, config: T5Config): # Initialize weights and apply final processing self.post_init() - # NOTE:(changwenbin,zhoukangkang) - # When you use 'paddle.incubate.jit.inference' to speed up your model, - # if 
you have set 'cache_static_model=True', - # you can use 'del self.encoder' to reduce the global memory usage. - # del self.encoder - def get_input_embeddings(self): return self.shared @@ -1612,11 +1605,7 @@ def forward( return_dict=return_dict, ) - if in_dynamic_or_pir_mode(): - return encoder_output - else: - # NOTE:(changwenbin,zhoukangkang)there is a bug in dy2s,we fix it here. - return encoder_output.last_hidden_state + return encoder_output class T5ForSequenceClassification(T5PretrainedModel): From c036878c9b7f183dd9e9d0bf70b68a77bdd3f07b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 19 Sep 2024 21:09:27 +0800 Subject: [PATCH 54/65] commit --- .../pipeline_stable_diffusion_3.py | 41 ++++--------------- 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 193b57b21..e96ac2fd9 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -19,7 +19,6 @@ import paddle from ppdiffusers.transformers import ( # T5TokenizerFast, - CLIPTextModelOutput, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, @@ -111,7 +110,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, FromSingleFileMixin): # SD3LoraLoaderMixin r""" Args: @@ -221,13 +220,8 @@ def _get_t5_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - - outputs = self.text_encoder_3(text_input_ids) - if paddle.incubate.jit.is_inference_mode(self.text_encoder_3): - # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference - prompt_embeds = outputs - else: - prompt_embeds = outputs[0] + # breakpoint() + prompt_embeds = self.text_encoder_3(text_input_ids)[0] dtype = self.text_encoder_3.dtype prompt_embeds = prompt_embeds.astype(dtype=dtype) @@ -274,23 +268,13 @@ def _get_clip_prompt_embeds( f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] - if paddle.incubate.jit.is_inference_mode(text_encoder): - # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference - pooled_prompt_embeds = prompt_embeds[-1] - if clip_skip is None: - prompt_embeds = prompt_embeds[:-2][-2] - else: - prompt_embeds = prompt_embeds[:-2][-(clip_skip + 2)] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] else: - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - pooled_prompt_embeds = pooled_prompt_embeds.astype(dtype=text_encoder.dtype) prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) _, seq_len, _ = prompt_embeds.shape @@ -809,7 +793,6 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - # in order to d2s noise_pred_out = self.transformer( hidden_states=latent_model_input, timestep=timestep, 
@@ -856,15 +839,7 @@ def __call__(
         else:
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
-            # in order to d2s
-            if paddle.incubate.jit.is_inference_mode(self.vae.decode):
-                latents = latents.cast("float32")
-            image_out = self.vae.decode(latents, return_dict=False)
-            if paddle.incubate.jit.is_inference_mode(self.vae.decode):
-                # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference
-                image = image_out
-            else:
-                image = image_out[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
 
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models

From d879bd8f61762d5c9d3f75653c9984cff959fbb2 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Thu, 19 Sep 2024 21:17:22 +0800
Subject: [PATCH 55/65] deploy README: drop the separate FP32 export step

---
 ppdiffusers/deploy/sd3/README.md | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index 5d84c3507..018238674 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -21,12 +21,7 @@ export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_e
 High-performance inference commands:
 ```shell
-# step 1: export the FP32 Paddle model and build an FP16 TensorRT engine from it.
-python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \
---num-inference-steps 50 --inference_optimize 1 \
---benchmark 1
-
-# step 2: run FP16 inference
+# run FP16 inference
 python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \
 --num-inference-steps 50 --inference_optimize 1 \
 --benchmark 1

From e4367eb4018340336bf5efed71df696079cddfc7 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Thu, 19 Sep 2024 22:07:01 +0800
Subject: [PATCH 56/65] split attention output by dynamic shapes and add an
 is_inference_mode fallback

---
 ppdiffusers/ppdiffusers/models/simplified_sd3.py | 12 ++++++------
 .../pipeline_stable_diffusion_3.py               | 11 ++++++++++-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 3901b3f54..6ad8e3572 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -109,14 +109,14 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
             norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim])
 
-            # attn_output, context_attn_output = paddle.split(
-            #     norm_hidden_states1, num_or_sections=[hidden_states.shape[1], encoder_hidden_states.shape[1]], axis=1
-            # )
-
-            attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
-                norm_hidden_states1, num_or_sections=[1024, 154], axis=1
-            )
+            attn_output, context_attn_output = paddle.split(
+                norm_hidden_states1, num_or_sections=[hidden_states.shape[1], encoder_hidden_states.shape[1]], axis=1
+            )
+
+            # attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
+            #     norm_hidden_states1, num_or_sections=[1024, 154], axis=1
+            # )
 
             attn_output = paddle.nn.functional.linear(
                 attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias
diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index e96ac2fd9..d1978d775 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -35,6 +35,15 @@
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusion3PipelineOutput
 
+try:
+    # paddle.incubate.jit.is_inference_mode is available in paddle develop but not in paddle 3.0beta, so we add a try/except fallback.
+    from paddle.incubate.jit import is_inference_mode
+except ImportError:
+
+    def is_inference_mode(func):
+        return False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
@@ -802,7 +811,7 @@ def __call__(
                 return_dict=False,
             )
 
-            if isinstance(noise_pred_out, paddle.Tensor):
+            if is_inference_mode(self.transformer):
                 noise_pred = noise_pred_out
             else:
                 noise_pred = noise_pred_out[0]

From 842370170624decda01e0b788a4caa8737e4a2d3 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Thu, 19 Sep 2024 23:59:36 +0800
Subject: [PATCH 57/65] cache the static model and drop the add_norm_fuse_pass
 workaround

---
 .../inference/text_to_image_generation-stable_diffusion_3.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index 562be5743..7e8b0e877 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -73,9 +73,8 @@ def parse_args():
         pipe.transformer,
         save_model_dir="./tmp/sd3",
         enable_new_ir=True,
-        cache_static_model=False,
+        cache_static_model=True,
         exp_enable_use_cutlass=True,
-        delete_pass_lists=["add_norm_fuse_pass"],
     )
 
 generator = paddle.Generator().manual_seed(42)

From b73f89951b61b091ae401ce4142c4076d5e993f9 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 00:00:10 +0800
Subject: [PATCH 58/65] always take hidden_states from the simplified_sd3
 output tuple

---
 ppdiffusers/ppdiffusers/models/transformer_sd3.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index b5c7a4286..3824143d4 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -341,11 +341,7 @@ def forward(
             out = self.simplified_sd3(
                 hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
             )
-            # this is for paddle inference.
-            if isinstance(out, paddle.Tensor):
-                hidden_states = out
-            else:
-                hidden_states = out[1]
+            hidden_states = out[1]
             encoder_hidden_states = None
 
         elif self.inference_optimize_origin:

From 841e739597a83f0d7443085581b9a4179198d5ce Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 00:03:18 +0800
Subject: [PATCH 59/65] hoist the qkv split sizes out of the per-layer loop

---
 ppdiffusers/ppdiffusers/models/simplified_sd3.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 6ad8e3572..fbf60feef 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -52,6 +52,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
         last_context_hidden_states = None
         last_context_gate_mlp = None
 
+        seq1 = hidden_states.shape[1]
+        seq2 = encoder_hidden_states.shape[1]
+
         for i in range(self.num_layers):
             context_pre_only = i == self.num_layers - 1
 
@@ -109,9 +112,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
             norm_hidden_states1 = norm_hidden_states1.reshape([2, -1, self.dim])
 
-            attn_output, context_attn_output = paddle.split(
-                norm_hidden_states1, num_or_sections=[hidden_states.shape[1], encoder_hidden_states.shape[1]], axis=1
-            )
+            attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)
 
             # attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
             #     norm_hidden_states1, num_or_sections=[1024, 154], axis=1

From fd2c0d913a749481a158994e2ebba5e240a0560f Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 00:13:17 +0800
Subject: [PATCH 60/65] handle dynamic sequence lengths in the triton_split
 InferShape

---
 paddlemix/triton_ops/triton_ops.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 41fc718bf..e03fd3166 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1620,7 +1620,13 @@ def fused_rotary_emb(
 std::vector<std::vector<int64_t>> ${op_name}_InferShape(
         const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
-    std::vector<int64_t> out_shape = {A_shape[0], A_shape[1]+B_shape[1], A_shape[2]/3};
+    int64_t seq1 = A_shape[1];
+    int64_t seq2 = B_shape[1];
+    int64_t seq = -1;
+    if (seq1 > 0 && seq2 > 0){
+        seq = seq1 + seq2;
+    }
+    std::vector<int64_t> out_shape = {A_shape[0], seq, A_shape[2]/3};
     return {out_shape, out_shape, out_shape};
 }

From bf70669f33123b7d44f1948d56b4f8214ea2a501 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 02:02:32 +0800
Subject: [PATCH 61/65] gate weight remapping on INFERENCE_OPTIMIZE and fold
 the qkv concat into a loop

---
 .../ppdiffusers/models/transformer_sd3.py | 47 ++++++------------
 1 file changed, 15 insertions(+), 32 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 3824143d4..5777c8c8f 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -383,6 +383,10 @@ def forward(
 
     @classmethod
     def custom_modify_weight(cls, state_dict):
+
+        if os.getenv("INFERENCE_OPTIMIZE") != "True":
+            return
+
        # NOTE:(changwenbin,zhoukangkang) SD3 num_layers is 24
        sd3_num_layers = 24
        for i in range(sd3_num_layers):
@@ -427,35 +431,14 @@ def custom_modify_weight(cls, state_dict):
             else:
                 print(f"Warning!!: '{from_}' not found in state_dict")
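            # [illustrative sketch, not part of the patch] The q/k/v concat below
            # works because paddle.nn.Linear stores weight as [in_features, out_features]:
            # stacking q/k/v along the last axis yields one [dim, 3*dim] weight and one
            # [3*dim] bias, so a single GEMM replaces three separate projections:
            #     w_qkv = paddle.concat([w_q, w_k, w_v], axis=-1)  # [dim, 3*dim]
            #     b_qkv = paddle.concat([b_q, b_k, b_v], axis=-1)  # [3*dim]
            #     qkv = paddle.nn.functional.linear(x, w_qkv, b_qkv)
            #     q, k, v = paddle.split(qkv, 3, axis=-1)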
-            state_dict[f"simplified_sd3.qkv.{i}.weight"] = paddle.concat(
-                [
-                    state_dict[f"simplified_sd3.q.{i}.weight"],
-                    state_dict[f"simplified_sd3.k.{i}.weight"],
-                    state_dict[f"simplified_sd3.v.{i}.weight"],
-                ],
-                axis=1,
-            )
-            state_dict[f"simplified_sd3.qkv.{i}.bias"] = paddle.concat(
-                [
-                    state_dict[f"simplified_sd3.q.{i}.bias"],
-                    state_dict[f"simplified_sd3.k.{i}.bias"],
-                    state_dict[f"simplified_sd3.v.{i}.bias"],
-                ],
-                axis=0,
-            )
-            state_dict[f"simplified_sd3.eqkv.{i}.weight"] = paddle.concat(
-                [
-                    state_dict[f"simplified_sd3.eq.{i}.weight"],
-                    state_dict[f"simplified_sd3.ek.{i}.weight"],
-                    state_dict[f"simplified_sd3.ev.{i}.weight"],
-                ],
-                axis=1,
-            )
-            state_dict[f"simplified_sd3.eqkv.{i}.bias"] = paddle.concat(
-                [
-                    state_dict[f"simplified_sd3.eq.{i}.bias"],
-                    state_dict[f"simplified_sd3.ek.{i}.bias"],
-                    state_dict[f"simplified_sd3.ev.{i}.bias"],
-                ],
-                axis=0,
-            )
+            # concat qkv weight and bias.
+            for placeholder1 in ["", "e"]:
+                for placeholder2 in ["weight", "bias"]:
+                    state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder2}"] = paddle.concat(
+                        [
+                            state_dict[f"simplified_sd3.{placeholder1}q.{i}.{placeholder2}"],
+                            state_dict[f"simplified_sd3.{placeholder1}k.{i}.{placeholder2}"],
+                            state_dict[f"simplified_sd3.{placeholder1}v.{i}.{placeholder2}"],
+                        ],
+                        axis=-1,
+                    )

From 7d8064cb7de574a8b485c35bd71391532d683bce Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 02:24:40 +0800
Subject: [PATCH 62/65] document the fused Triton adaptive layer norm

---
 ppdiffusers/ppdiffusers/models/normalization.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py
index 33e715675..68a732cc0 100644
--- a/ppdiffusers/ppdiffusers/models/normalization.py
+++ b/ppdiffusers/ppdiffusers/models/normalization.py
@@ -192,6 +192,8 @@ def forward(self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor) -> pa
         emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype))
         scale, shift = paddle.chunk(emb, 2, axis=1)
         if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+            # NOTE:(changwenbin,zhoukangkang)
+            # This is a fused faster op using Triton, only used in inference, not used in training.
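            # [illustrative sketch, not part of the patch] An eager-mode equivalent of
            # the fused kernel, assuming x is [batch, seq, dim] and scale/shift are [batch, dim]:
            #     def adaptive_layer_norm_ref(x, scale, shift, weight, bias, epsilon=1e-6):
            #         y = paddle.nn.functional.layer_norm(x, x.shape[-1:], weight, bias, epsilon)
            #         return y * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
            # The Triton kernel fuses the normalization and the affine modulation into a
            # single pass over x.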
import paddlemix x = paddlemix.triton_ops.adaptive_layer_norm(x, scale, shift, self.norm.weight, self.norm.bias) From e2e287975e71cb6d7c581fc3596ff19f3223bb59 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 20 Sep 2024 09:13:12 +0800 Subject: [PATCH 63/65] remove inference_optimize_origin --- ...ext_to_image_generation-stable_diffusion_3.py | 8 -------- .../ppdiffusers/models/transformer_sd3.py | 16 ---------------- 2 files changed, 24 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 7e8b0e877..e2eb1cd5f 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -38,12 +38,6 @@ def parse_args(): default=False, help="If inference_optimize is set to True, all optimizations except Triton are enabled.", ) - parser.add_argument( - "--inference_optimize_origin", - type=(lambda x: str(x).lower() in ["true", "1", "yes"]), - default=False, - help="If inference_optimize_origin is set to True, the original dynamic graph inference optimization is enabled.", - ) parser.add_argument("--height", type=int, default=512, help="Height of the generated image.") parser.add_argument("--width", type=int, default=512, help="Width of the generated image.") parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.") @@ -57,8 +51,6 @@ def parse_args(): if args.inference_optimize: os.environ["INFERENCE_OPTIMIZE"] = "True" os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" -if args.inference_optimize_origin: - os.environ["INFERENCE_OPTIMIZE_ORIGIN"] = "True" if args.dtype == "float32": inference_dtype = paddle.float32 elif args.dtype == "float16": diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index 5777c8c8f..c276d4d21 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -95,7 +95,6 @@ def __init__( self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True" - self.inference_optimize_origin = os.getenv("INFERENCE_OPTIMIZE_ORIGIN") == "True" # `attention_head_dim` is doubled to account for the mixing. # It needs to crafted when we get the actual checkpoints. 
        self.transformer_blocks = nn.LayerList(
@@ -118,14 +117,6 @@ def __init__(
                 num_attention_heads=self.config.num_attention_heads,
                 attention_head_dim=self.inner_dim,
             )
-        elif self.inference_optimize_origin:
-            self.sd3_origin_transformer = paddle.incubate.jit.inference(
-                self.sd3_origin_transformer,
-                enable_new_ir=True,
-                cache_static_model=False,
-                exp_enable_use_cutlass=True,
-                delete_pass_lists=["add_norm_fuse_pass"],
-            )
 
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias_attr=True)
@@ -343,13 +334,6 @@ def forward(
             )
             hidden_states = out[1]
             encoder_hidden_states = None
-
-        elif self.inference_optimize_origin:
-            hidden_states = self.sd3_origin_transformer(
-                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-            )
-            encoder_hidden_states = None
-
         else:
             encoder_hidden_states, hidden_states = self.sd3_origin_transformer(
                 hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb

From 2aba4ecc65996b57a2b7e26e03da81d054789787 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 11:01:34 +0800
Subject: [PATCH 64/65] polish the deploy README and rename noise_pred_out to
 model_output

---
 ppdiffusers/deploy/sd3/README.md                          | 9 +++------
 .../stable_diffusion_3/pipeline_stable_diffusion_3.py     | 6 +++---
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index 018238674..5e59bbb07 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -8,14 +8,11 @@ python -m pip install triton
 python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
 python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
 
-# install the develop build of paddle
+# install the develop build of paddle; pick the wheel that matches your CUDA version (CUDA 12.3 is used here)
 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
 
-# set the TensorRT lib path
-export LD_LIBRARY_PATH=/your_TensorRT_dir//lib:$LD_LIBRARY_PATH
-
-# set the cutlass package path
-export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH
+# set the path of libCutlassGemmEpilogue.so
+# see https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/README.md for details
 export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
 ```

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index d1978d775..134f648d8 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -802,7 +802,7 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-            noise_pred_out = self.transformer(
+            model_output = self.transformer(
                 hidden_states=latent_model_input,
                 timestep=timestep,
                 encoder_hidden_states=prompt_embeds,
@@ -812,9 +812,9 @@ def __call__(
             )
 
             if is_inference_mode(self.transformer):
-                noise_pred = noise_pred_out
+                noise_pred = model_output
             else:
-                noise_pred = noise_pred_out[0]
+                noise_pred = model_output[0]
 
             # perform guidance
             if self.do_classifier_free_guidance:

From b291a7e9883e40742ec944eb9f038cc824675083 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Fri, 20 Sep 2024 11:05:18 +0800
Subject: [PATCH 65/65] add a note to the inference-mode branch

---
 .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 134f648d8..000803b1b 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -812,6 +812,8 @@ def __call__(
             )
 
             if is_inference_mode(self.transformer):
+                # NOTE:(changwenbin,zhoukangkang)
+                # This is for paddle inference mode
                 noise_pred = model_output
             else:
                 noise_pred = model_output[0]
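Taken together, the output-handling convention that these patches converge on can be summarized in one place. The sketch below is a recap rather than code from the series; only the `is_inference_mode` import path and the Tensor-versus-tuple behaviour are taken from the patches above, and `unwrap` is a hypothetical helper name:

```python
# Older wheels (e.g. paddle 3.0beta) lack is_inference_mode, so fall back to a
# stub that reports "not wrapped", mirroring the shim added in PATCH 56.
try:
    from paddle.incubate.jit import is_inference_mode
except ImportError:

    def is_inference_mode(func):
        return False


def unwrap(module, output):
    # A paddle.incubate.jit.inference-wrapped module returns a bare Tensor,
    # while the eager module returns a tuple whose first item is the tensor.
    return output if is_inference_mode(module) else output[0]
```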