From 59f23a0da9d078ff477519f99db71e125c47826d Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 29 Jul 2024 03:58:24 +0000 Subject: [PATCH 01/24] modified the dit --- .../ppdiffusers/models/modeling_utils.py | 5 + .../ppdiffusers/models/transformer_2d.py | 164 ++++++++++++++---- 2 files changed, 133 insertions(+), 36 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index e4462d284..3bcb620a6 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + @classmethod + def custom_modify_weight(cls, state_dict): + pass + @classmethod def _load_pretrained_model( cls, @@ -1130,6 +1134,7 @@ def _find_mismatched_keys( error_msgs.append( f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}." ) + cls.custom_modify_weight(state_dict) faster_set_state_dict(model_to_load, state_dict) missing_keys = sorted(list(set(expected_keys) - set(loaded_keys))) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 3a1084d03..a3278a04a 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -28,6 +28,8 @@ recompute_use_reentrant, use_old_recompute, ) +from .zkk_facebook_dit import ZKKFacebookDIT + from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed from .lora import LoRACompatibleConv, LoRACompatibleLinear @@ -213,6 +215,8 @@ def __init__( for d in range(num_layers) ] ) + + self.tmp_ZKKFacebookDIT = ZKKFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -385,41 +389,44 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - for block in self.transformer_blocks: - if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - hidden_states = recompute( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - + # for block in self.transformer_blocks: + # if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): + + # def create_custom_forward(module, return_dict=None): + # def custom_forward(*inputs): + # if return_dict is not None: + # return module(*inputs, return_dict=return_dict) + # else: + # return module(*inputs) + + # return custom_forward + + # ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + # hidden_states = recompute( + # create_custom_forward(block), + # hidden_states, + # attention_mask, + # encoder_hidden_states, + # encoder_attention_mask, + # timestep, + # cross_attention_kwargs, + # class_labels, + # **ckpt_kwargs, + # ) + # else: + # hidden_states = block( + # hidden_states, + # attention_mask=attention_mask, + # encoder_hidden_states=encoder_hidden_states, + # encoder_attention_mask=encoder_attention_mask, + # timestep=timestep, + # cross_attention_kwargs=cross_attention_kwargs, + # class_labels=class_labels, + # ) + + + hidden_states = self.tmp_ZKKFacebookDIT(hidden_states, timestep, class_labels) + # 3. 
Output if self.is_input_continuous: if not self.use_linear_projection: @@ -473,7 +480,8 @@ def custom_forward(*inputs): hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) - hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) + #hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose([0,5,1,3,2,4]) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) @@ -482,3 +490,87 @@ def custom_forward(*inputs): return (output,) return Transformer2DModelOutput(sample=output) + + @classmethod + def custom_modify_weight(cls, state_dict): + for key in list(state_dict.keys()): + if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight' + if part == 'to_q' and qkv_key_w not in state_dict: + state_dict[qkv_key_w] = state_dict.pop(key) + elif part in ('to_k', 'to_v'): + qkv = state_dict.get(qkv_key_w) + if qkv is not None: + state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) + if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias' + if part == 'to_q' and qkv_key_b not in state_dict: + state_dict[qkv_key_b] = state_dict.pop(key) + elif part in ('to_k', 'to_v'): + qkv = state_dict.get(qkv_key_b) + if qkv is not None: + state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) + + for key in list(state_dict.keys()): + name = "" + if 'attn1.to_qkv.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.weight'.format(layer_id) + if 'attn1.to_qkv.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.bias'.format(layer_id) + + if 'attn1.to_out.0.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.weight'.format(layer_id) + if 'attn1.to_out.0.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.bias'.format(layer_id) + + if 'ff.net.0.proj.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.weight'.format(layer_id) + if 'ff.net.0.proj.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.bias'.format(layer_id) + + if 'ff.net.2.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.weight'.format(layer_id) + if 'ff.net.2.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.bias'.format(layer_id) + + + if 'norm1.emb.timestep_embedder.linear_1.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.weight'.format(layer_id) + if 'norm1.emb.timestep_embedder.linear_1.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.bias'.format(layer_id) + + + if 'norm1.emb.timestep_embedder.linear_2.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.weight'.format(layer_id) + if 'norm1.emb.timestep_embedder.linear_2.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = 
f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.bias'.format(layer_id) + + + if 'norm1.linear.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.weight'.format(layer_id) + if 'norm1.linear.bias' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.bias'.format(layer_id) + + if 'class_embedder.embedding_table.weight' in key: + layer_id = (int)(key.split(".")[1]) + name = f'tmp_ZKKFacebookDIT.embs.{layer_id}.weight'.format(layer_id) + + state_dict[name] = paddle.assign(state_dict[key]) From 5fee64b214a6f4ffc508093dfda34bec67973da6 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 29 Jul 2024 04:24:16 +0000 Subject: [PATCH 02/24] add zkk_facebook --- .../ppdiffusers/models/zkk_facebook_dit.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py diff --git a/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py b/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py new file mode 100644 index 000000000..f5b42d7af --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py @@ -0,0 +1,69 @@ +from paddle import nn +import paddle +import paddle.nn.functional as F + +class ZKKFacebookDIT(nn.Layer): + def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): + super().__init__() + self.num_layers = num_layers + self.dtype = "float16" + self.attention_head_dim = attention_head_dim + + self.fcs0 = nn.LayerList([nn.Linear(256, 1152) for i in range(self.num_layers)]) + self.fcs1 = nn.LayerList([nn.Linear(1152, 1152) for i in range(self.num_layers)]) + self.fcs2 = nn.LayerList([nn.Linear(1152, 6912) for i in range(self.num_layers)]) + self.embs = nn.LayerList([nn.Embedding(1001, 1152) for i in range(self.num_layers)]) + + self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)]) + self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)]) + self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(self.num_layers)]) + self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) + + @paddle.incubate.jit.inference(enable_new_ir=True, + cache_static_model=True, + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], + ) + def forward(self,hidden_states, timestep, class_labels): + + tmp = paddle.arange(dtype='float32', end=128) + tmp = tmp * -9.21034049987793 * 0.007874015718698502 + tmp = paddle.exp(tmp).reshape([1,128]) + + timestep = timestep.cast("float32") + timestep = timestep.reshape([2,1]) + + tmp = tmp * timestep + + tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1) + common_tmp = tmp.cast(self.dtype) + + for i in range(self.num_layers): + tmp = self.fcs0[i](common_tmp) + tmp = F.silu(tmp) + tmp = self.fcs1[i](tmp) + tmp = tmp + self.embs[i](class_labels) + tmp = F.silu(tmp) + tmp = self.fcs2[i](tmp) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1) + norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) + q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) + q = q.reshape([2,256,16,72]) + k = k.reshape([2,256,16,72]) + v = v.reshape([2,256,16,72]) + + norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) + norm_hidden_states = norm_hidden_states.reshape([2,256,1152]) + norm_hidden_states = self.out_proj[i](norm_hidden_states) + + hidden_states = hidden_states + 
norm_hidden_states * gate_msa.reshape([2,1,1152]) + + norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) + + norm_hidden_states = self.ffn1[i](norm_hidden_states) + norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) + norm_hidden_states = self.ffn2[i](norm_hidden_states) + + hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152]) + + return hidden_states \ No newline at end of file From f653a66bac9ff75c4f32177a7acb7a3baf8bb300 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 29 Jul 2024 12:53:59 +0000 Subject: [PATCH 03/24] update zkk_facebook_dit.py --- .../ppdiffusers/models/zkk_facebook_dit.py | 125 ++++++++++++------ 1 file changed, 87 insertions(+), 38 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py b/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py index f5b42d7af..04069d3b4 100644 --- a/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py @@ -1,18 +1,34 @@ from paddle import nn import paddle import paddle.nn.functional as F +import math class ZKKFacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers self.dtype = "float16" + self.dim = dim + self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim - - self.fcs0 = nn.LayerList([nn.Linear(256, 1152) for i in range(self.num_layers)]) - self.fcs1 = nn.LayerList([nn.Linear(1152, 1152) for i in range(self.num_layers)]) - self.fcs2 = nn.LayerList([nn.Linear(1152, 6912) for i in range(self.num_layers)]) - self.embs = nn.LayerList([nn.Embedding(1001, 1152) for i in range(self.num_layers)]) + self.timestep_embedder_in_channels = 256 + self.timestep_embedder_time_embed_dim = 1152 + self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim + self.CombinedTimestepLabelEmbeddings_num_embeddings = 1001 + self.CombinedTimestepLabelEmbeddings_embedding_dim = 1152 + + self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, + self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) + + self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, + self.timestep_embedder_time_embed_dim_out) for i in range(self.num_layers)]) + + self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, + 6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) + + self.embs = nn.LayerList([nn.Embedding(self.CombinedTimestepLabelEmbeddings_embedding_dim, + self.CombinedTimestepLabelEmbeddings_num_embeddings) for i in range(self.num_layers)]) + self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)]) self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)]) @@ -24,46 +40,79 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio exp_enable_use_cutlass=True, delete_pass_lists=["add_norm_fuse_pass"], ) - def forward(self,hidden_states, timestep, class_labels): - - tmp = paddle.arange(dtype='float32', end=128) - tmp = tmp * -9.21034049987793 * 0.007874015718698502 - tmp = paddle.exp(tmp).reshape([1,128]) - - timestep = timestep.cast("float32") - timestep = timestep.reshape([2,1]) - - tmp = tmp * timestep - - tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1) - common_tmp = tmp.cast(self.dtype) - - for i in range(self.num_layers): - tmp = 
self.fcs0[i](common_tmp) - tmp = F.silu(tmp) - tmp = self.fcs1[i](tmp) - tmp = tmp + self.embs[i](class_labels) - tmp = F.silu(tmp) - tmp = self.fcs2[i](tmp) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1) + def forward(self, hidden_states, timesteps, class_labels): + + # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py + num_channels = 256 + max_period = 10000 + downscale_freq_shift = 1 + half_dim = num_channels // 2 + exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32") + exponent = exponent / (half_dim - downscale_freq_shift) + emb = paddle.exp(exponent) + emb = timesteps[:, None].cast("float32") * emb[None, :] + emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) + common_emb = emb.cast(self.dtype) + + for i in range(self.num_layers): #$$ for? + emb = self.fcs0[i](common_emb) + emb = F.silu(emb) + emb = self.fcs1[i](emb) + emb = emb + self.embs[i](class_labels) + emb = F.silu(emb) + emb = self.fcs2[i](emb) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) - q = q.reshape([2,256,16,72]) - k = k.reshape([2,256,16,72]) - v = v.reshape([2,256,16,72]) - + b,s,h = q.shape + q = q.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) + k = k.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) + v = v.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) + norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) - norm_hidden_states = norm_hidden_states.reshape([2,256,1152]) + norm_hidden_states = norm_hidden_states.reshape([b,s,self.dim]) norm_hidden_states = self.out_proj[i](norm_hidden_states) - - hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([2,1,1152]) - norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) + # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) + # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) + + hidden_states,norm_hidden_states = paddle.incubate.tt.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp) norm_hidden_states = self.ffn1[i](norm_hidden_states) norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) norm_hidden_states = self.ffn2[i](norm_hidden_states) - hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152]) - - return hidden_states \ No newline at end of file + hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b,1,self.dim]) + + return hidden_states + + + # tmp = paddle.arange(dtype='float32', end=128) + # tmp = tmp * -9.21034049987793 * 0.007874015718698502 + # tmp = paddle.exp(tmp).reshape([1,128]) + # timestep = timestep.cast("float32") + # timestep = timestep.reshape([2,1]) + # tmp = tmp * timestep + # tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1) + # common_tmp = tmp.cast(self.dtype) + # for i in range(self.num_layers): + # tmp = self.fcs0[i](common_tmp) + # tmp = F.silu(tmp) + # tmp = self.fcs1[i](tmp) + # tmp = tmp + self.embs[i](class_labels) + # tmp = F.silu(tmp) + # tmp = self.fcs2[i](tmp) + # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1) + # norm_hidden_states = 
paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) + # q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) + # orm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) + # norm_hidden_states = norm_hidden_states.reshape([2,256,1152]) + # orm_hidden_states = self.out_proj[i](norm_hidden_states) + # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([2,1,1152]) + # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) + # norm_hidden_states = self.ffn1[i](norm_hidden_states) + # norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) + # norm_hidden_states = self.ffn2[i](norm_hidden_states) + # hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152]) + # return hidden_states + \ No newline at end of file From 3b29d9df24feb401fcf6d843d62215f6c9ffe155 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Tue, 30 Jul 2024 04:17:59 +0000 Subject: [PATCH 04/24] update transformer_2d --- .../ppdiffusers/models/transformer_2d.py | 82 ++++++------------- 1 file changed, 23 insertions(+), 59 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index a3278a04a..97454dd91 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -515,62 +515,26 @@ def custom_modify_weight(cls, state_dict): if qkv is not None: state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) - for key in list(state_dict.keys()): - name = "" - if 'attn1.to_qkv.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.weight'.format(layer_id) - if 'attn1.to_qkv.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.bias'.format(layer_id) - - if 'attn1.to_out.0.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.weight'.format(layer_id) - if 'attn1.to_out.0.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.bias'.format(layer_id) - - if 'ff.net.0.proj.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.weight'.format(layer_id) - if 'ff.net.0.proj.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.bias'.format(layer_id) - - if 'ff.net.2.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.weight'.format(layer_id) - if 'ff.net.2.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.bias'.format(layer_id) - - - if 'norm1.emb.timestep_embedder.linear_1.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.weight'.format(layer_id) - if 'norm1.emb.timestep_embedder.linear_1.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.bias'.format(layer_id) - - - if 'norm1.emb.timestep_embedder.linear_2.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.weight'.format(layer_id) - if 'norm1.emb.timestep_embedder.linear_2.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.bias'.format(layer_id) - - - if 'norm1.linear.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = 
f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.weight'.format(layer_id) - if 'norm1.linear.bias' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.bias'.format(layer_id) - - if 'class_embedder.embedding_table.weight' in key: - layer_id = (int)(key.split(".")[1]) - name = f'tmp_ZKKFacebookDIT.embs.{layer_id}.weight'.format(layer_id) - - state_dict[name] = paddle.assign(state_dict[key]) + map_from_my_dit = {} + for i in range(28): + map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' + map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' + map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' + map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' + + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' + map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' + + map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + + for key in map_from_my_dit.keys(): + state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) + \ No newline at end of file From a88caeaa64ea629ce3fddc7a510816d178f7dab0 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 31 Jul 2024 05:45:10 +0000 Subject: [PATCH 05/24] update dit optimize --- .../class_conditional_image_generation-dit.py | 40 +++++- ...kk_facebook_dit.py => sim_facebook_dit.py} | 36 +----- .../ppdiffusers/models/transformer_2d.py | 120 +++++++++--------- 3 files changed, 103 insertions(+), 93 deletions(-) rename ppdiffusers/ppdiffusers/models/{zkk_facebook_dit.py => sim_facebook_dit.py} (71%) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 71d73ec0d..911f95cab 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os +os.environ["CUDA_VISIBLE_DEVICES"] = "2" import paddle from paddlenlp.trainer import set_seed - from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +os.environ["Inference_Optimize"] = "True" + +dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) set_seed(42) @@ -27,4 +30,35 @@ image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image.save("class_conditional_image_generation-dit-result.png") +# image.save("class_conditional_image_generation-dit-result.png") +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + + +import datetime +import time + +warm_up_times = 5 +repeat_times = 10 +sum_time = 0. + +for i in range(repeat_times): + paddle.device.synchronize() + starttime = datetime.datetime.now() + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + paddle.device.synchronize() + endtime = datetime.datetime.now() + duringtime = endtime - starttime + time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + evet = "every_time: " + str(time_ms) + "ms\n\n" + with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file: + time_file.write(evet) + sum_time+=time_ms +print("The ave end to end time : ", sum_time / repeat_times, "ms") +msg = "average_time: " + str(sum_time / repeat_times) + "ms\n\n" +print(msg) +with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file: + time_file.write(msg) + +image.save("class_conditional_image_generation-dit-29.png") + diff --git a/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py b/ppdiffusers/ppdiffusers/models/sim_facebook_dit.py similarity index 71% rename from ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py rename to ppdiffusers/ppdiffusers/models/sim_facebook_dit.py index 04069d3b4..43815c903 100644 --- a/ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/sim_facebook_dit.py @@ -3,7 +3,7 @@ import paddle.nn.functional as F import math -class ZKKFacebookDIT(nn.Layer): +class SIM_FacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers @@ -54,7 +54,7 @@ def forward(self, hidden_states, timesteps, class_labels): emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) common_emb = emb.cast(self.dtype) - for i in range(self.num_layers): #$$ for? 
+ for i in range(self.num_layers): emb = self.fcs0[i](common_emb) emb = F.silu(emb) emb = self.fcs1[i](emb) @@ -85,34 +85,4 @@ def forward(self, hidden_states, timesteps, class_labels): hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b,1,self.dim]) return hidden_states - - - # tmp = paddle.arange(dtype='float32', end=128) - # tmp = tmp * -9.21034049987793 * 0.007874015718698502 - # tmp = paddle.exp(tmp).reshape([1,128]) - # timestep = timestep.cast("float32") - # timestep = timestep.reshape([2,1]) - # tmp = tmp * timestep - # tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1) - # common_tmp = tmp.cast(self.dtype) - # for i in range(self.num_layers): - # tmp = self.fcs0[i](common_tmp) - # tmp = F.silu(tmp) - # tmp = self.fcs1[i](tmp) - # tmp = tmp + self.embs[i](class_labels) - # tmp = F.silu(tmp) - # tmp = self.fcs2[i](tmp) - # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1) - # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) - # q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) - # orm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) - # norm_hidden_states = norm_hidden_states.reshape([2,256,1152]) - # orm_hidden_states = self.out_proj[i](norm_hidden_states) - # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([2,1,1152]) - # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) - # norm_hidden_states = self.ffn1[i](norm_hidden_states) - # norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) - # norm_hidden_states = self.ffn2[i](norm_hidden_states) - # hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152]) - # return hidden_states - \ No newline at end of file + \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 97454dd91..5259b0954 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -28,13 +28,15 @@ recompute_use_reentrant, use_old_recompute, ) -from .zkk_facebook_dit import ZKKFacebookDIT +from .sim_facebook_dit import SIM_FacebookDIT from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin from .normalization import AdaLayerNormSingle +import os + @dataclass @@ -116,6 +118,8 @@ def __init__( self.inner_dim = inner_dim = num_attention_heads * attention_head_dim self.data_format = data_format + self.Inference_Optimize = bool(os.getenv('Inference_Optimize')) + conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear @@ -216,7 +220,7 @@ def __init__( ] ) - self.tmp_ZKKFacebookDIT = ZKKFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + self.FacebookDIT = SIM_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -254,6 +258,7 @@ def __init__( self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) self.gradient_checkpointing = False + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -388,45 +393,46 @@ def forward( batch_size = hidden_states.shape[0] encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - - # for block in self.transformer_blocks: - # if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): - - # def create_custom_forward(module, return_dict=None): - # def custom_forward(*inputs): - # if return_dict is not None: - # return module(*inputs, return_dict=return_dict) - # else: - # return module(*inputs) - - # return custom_forward - - # ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - # hidden_states = recompute( - # create_custom_forward(block), - # hidden_states, - # attention_mask, - # encoder_hidden_states, - # encoder_attention_mask, - # timestep, - # cross_attention_kwargs, - # class_labels, - # **ckpt_kwargs, - # ) - # else: - # hidden_states = block( - # hidden_states, - # attention_mask=attention_mask, - # encoder_hidden_states=encoder_hidden_states, - # encoder_attention_mask=encoder_attention_mask, - # timestep=timestep, - # cross_attention_kwargs=cross_attention_kwargs, - # class_labels=class_labels, - # ) - - - hidden_states = self.tmp_ZKKFacebookDIT(hidden_states, timestep, class_labels) + + if self.Inference_Optimize is True: + hidden_states = self.FacebookDIT(hidden_states, timestep, class_labels) + else: + for block in self.transformer_blocks: + if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + hidden_states = recompute( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + # 3. 
Output if self.is_input_continuous: if not self.use_linear_projection: @@ -517,23 +523,23 @@ def custom_modify_weight(cls, state_dict): map_from_my_dit = {} for i in range(28): - map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' - map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' - map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' - map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' - - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' - map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' - - map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + map_from_my_dit[f'FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' + map_from_my_dit[f'FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' + map_from_my_dit[f'FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' + map_from_my_dit[f'FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' + map_from_my_dit[f'FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' + map_from_my_dit[f'FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' + map_from_my_dit[f'FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' + map_from_my_dit[f'FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' + + map_from_my_dit[f'FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' + map_from_my_dit[f'FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' + map_from_my_dit[f'FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' + map_from_my_dit[f'FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' + map_from_my_dit[f'FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' + map_from_my_dit[f'FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' + + map_from_my_dit[f'FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' for key in map_from_my_dit.keys(): state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) From 
54eeec2cb5446bd3ab45476729b3f04b973c58ce Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 31 Jul 2024 05:53:48 +0000 Subject: [PATCH 06/24] update transformer_2d --- ppdiffusers/ppdiffusers/models/transformer_2d.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 5259b0954..ea9a59da6 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -219,8 +219,8 @@ def __init__( for d in range(num_layers) ] ) - - self.FacebookDIT = SIM_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + if self.Inference_Optimize: + self.FacebookDIT = SIM_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -395,7 +395,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - if self.Inference_Optimize is True: + if self.Inference_Optimize: hidden_states = self.FacebookDIT(hidden_states, timestep, class_labels) else: for block in self.transformer_blocks: @@ -486,8 +486,7 @@ def custom_forward(*inputs): hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) - #hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) - hidden_states = hidden_states.transpose([0,5,1,3,2,4]) + hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) From 28a62c0a7994f8746e7d6d4805a1a222f2465f36 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 1 Aug 2024 05:10:01 +0000 Subject: [PATCH 07/24] rename facebook_dit --- .../class_conditional_image_generation-dit.py | 8 +++- ...book_dit.py => simplified_facebook_dit.py} | 12 +++--- .../ppdiffusers/models/transformer_2d.py | 41 +++++++++---------- 3 files changed, 32 insertions(+), 29 deletions(-) rename ppdiffusers/ppdiffusers/models/{sim_facebook_dit.py => simplified_facebook_dit.py} (90%) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 911f95cab..6249af008 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -18,8 +18,12 @@ from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -os.environ["Inference_Optimize"] = "True" - +Inference_Optimize = True +if Inference_Optimize: + os.environ["Inference_Optimize"] = "True" +else: + pass + dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) diff --git a/ppdiffusers/ppdiffusers/models/sim_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py similarity index 90% rename from ppdiffusers/ppdiffusers/models/sim_facebook_dit.py rename to ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 43815c903..c56ce21a0 100644 --- a/ppdiffusers/ppdiffusers/models/sim_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -3,7 +3,7 @@ import paddle.nn.functional as F import math -class SIM_FacebookDIT(nn.Layer): +class 
Simplified_FacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers @@ -14,8 +14,8 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.timestep_embedder_in_channels = 256 self.timestep_embedder_time_embed_dim = 1152 self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim - self.CombinedTimestepLabelEmbeddings_num_embeddings = 1001 - self.CombinedTimestepLabelEmbeddings_embedding_dim = 1152 + self.LabelEmbedding_num_classes = 1001 + self.LabelEmbedding_num_hidden_size = 1152 self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) @@ -26,8 +26,8 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, 6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) - self.embs = nn.LayerList([nn.Embedding(self.CombinedTimestepLabelEmbeddings_embedding_dim, - self.CombinedTimestepLabelEmbeddings_num_embeddings) for i in range(self.num_layers)]) + self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes, + self.LabelEmbedding_num_hidden_size) for i in range(self.num_layers)]) self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)]) @@ -36,7 +36,7 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) @paddle.incubate.jit.inference(enable_new_ir=True, - cache_static_model=True, + cache_static_model=False, exp_enable_use_cutlass=True, delete_pass_lists=["add_norm_fuse_pass"], ) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index ea9a59da6..c120188d1 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -28,7 +28,7 @@ recompute_use_reentrant, use_old_recompute, ) -from .sim_facebook_dit import SIM_FacebookDIT +from .simplified_facebook_dit import Simplified_FacebookDIT from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed @@ -220,7 +220,7 @@ def __init__( ] ) if self.Inference_Optimize: - self.FacebookDIT = SIM_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + self.Simplified_FacebookDIT = Simplified_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -394,9 +394,8 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - if self.Inference_Optimize: - hidden_states = self.FacebookDIT(hidden_states, timestep, class_labels) + hidden_states =self.Simplified_FacebookDIT(hidden_states, timestep, class_labels) else: for block in self.transformer_blocks: if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): @@ -522,23 +521,23 @@ def custom_modify_weight(cls, state_dict): map_from_my_dit = {} for i in range(28): - map_from_my_dit[f'FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' - map_from_my_dit[f'FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' - map_from_my_dit[f'FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' - map_from_my_dit[f'FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' - map_from_my_dit[f'FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' - map_from_my_dit[f'FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' - map_from_my_dit[f'FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' - map_from_my_dit[f'FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' - - map_from_my_dit[f'FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' - map_from_my_dit[f'FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' - map_from_my_dit[f'FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' - map_from_my_dit[f'FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' - map_from_my_dit[f'FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' - map_from_my_dit[f'FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' - - map_from_my_dit[f'FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' + map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' + map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' + map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' + map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' + map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' + map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' + map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' + + map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' + map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.weight'] = 
f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' + map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' + + map_from_my_dit[f'Simplified_FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' for key in map_from_my_dit.keys(): state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) From 7d49c490ed6f351da5b17f48f72be050382436d8 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 5 Aug 2024 03:41:16 +0000 Subject: [PATCH 08/24] Fixed the original dynamic image bug --- ppdiffusers/ppdiffusers/models/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index 3bcb620a6..b31bdd284 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1134,7 +1134,8 @@ def _find_mismatched_keys( error_msgs.append( f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}." ) - cls.custom_modify_weight(state_dict) + if os.getenv('Inference_Optimize'): + cls.custom_modify_weight(state_dict) faster_set_state_dict(model_to_load, state_dict) missing_keys = sorted(list(set(expected_keys) - set(loaded_keys))) From b03aa8ea9981db53fbef7e296019ec9d1fac52eb Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 5 Aug 2024 04:28:50 +0000 Subject: [PATCH 09/24] update triton op import paddlemix --- .../ppdiffusers/models/simplified_facebook_dit.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index c56ce21a0..1556d45dc 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -36,7 +36,7 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) @paddle.incubate.jit.inference(enable_new_ir=True, - cache_static_model=False, + cache_static_model=True, exp_enable_use_cutlass=True, delete_pass_lists=["add_norm_fuse_pass"], ) @@ -62,7 +62,8 @@ def forward(self, hidden_states, timesteps, class_labels): emb = F.silu(emb) emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) + import paddlemix + norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) b,s,h = q.shape q = q.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) @@ -74,9 +75,8 @@ def forward(self, hidden_states, timesteps, class_labels): norm_hidden_states = self.out_proj[i](norm_hidden_states) # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) - # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) - - hidden_states,norm_hidden_states = 
paddle.incubate.tt.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp) + # norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) + hidden_states,norm_hidden_states =paddlemix.triton_ops.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp) norm_hidden_states = self.ffn1[i](norm_hidden_states) norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) From cb86d1789a5ebee35b8bcd95e35faef641fdd767 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 05:38:15 +0000 Subject: [PATCH 10/24] update dit --- .../ppdiffusers/models/modeling_utils.py | 3 +- .../models/simplified_facebook_dit.py | 5 +- .../ppdiffusers/models/transformer_2d.py | 90 +++++++++---------- 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index b31bdd284..3bcb620a6 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1134,8 +1134,7 @@ def _find_mismatched_keys( error_msgs.append( f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}." ) - if os.getenv('Inference_Optimize'): - cls.custom_modify_weight(state_dict) + cls.custom_modify_weight(state_dict) faster_set_state_dict(model_to_load, state_dict) missing_keys = sorted(list(set(expected_keys) - set(loaded_keys))) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 1556d45dc..14a1982b1 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -7,7 +7,6 @@ class Simplified_FacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers - self.dtype = "float16" self.dim = dim self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim @@ -36,7 +35,7 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) @paddle.incubate.jit.inference(enable_new_ir=True, - cache_static_model=True, + cache_static_model=False, exp_enable_use_cutlass=True, delete_pass_lists=["add_norm_fuse_pass"], ) @@ -52,7 +51,7 @@ def forward(self, hidden_states, timesteps, class_labels): emb = paddle.exp(exponent) emb = timesteps[:, None].cast("float32") * emb[None, :] emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) - common_emb = emb.cast(self.dtype) + common_emb = emb.cast(hidden_states.dtype) for i in range(self.num_layers): emb = self.fcs0[i](common_emb) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index c120188d1..b82292b68 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -497,48 +497,48 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - for key in list(state_dict.keys()): - if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight' - if part == 'to_q' and qkv_key_w not in 
state_dict: - state_dict[qkv_key_w] = state_dict.pop(key) - elif part in ('to_k', 'to_v'): - qkv = state_dict.get(qkv_key_w) - if qkv is not None: - state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) - if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias' - if part == 'to_q' and qkv_key_b not in state_dict: - state_dict[qkv_key_b] = state_dict.pop(key) - elif part in ('to_k', 'to_v'): - qkv = state_dict.get(qkv_key_b) - if qkv is not None: - state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) - - map_from_my_dit = {} - for i in range(28): - map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' - map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' - map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' - map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' - map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' - map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' - map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' - map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' - - map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' - map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' - map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' - - map_from_my_dit[f'Simplified_FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' - - for key in map_from_my_dit.keys(): - state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) - \ No newline at end of file + if os.getenv('Inference_Optimize'): + for key in list(state_dict.keys()): + if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight' + if part == 'to_q' and qkv_key_w not in state_dict: + state_dict[qkv_key_w] = state_dict.pop(key) + elif part in ('to_k', 'to_v'): + qkv = state_dict.get(qkv_key_w) + if qkv is not None: + state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) + if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias' + if part == 'to_q' and qkv_key_b not in state_dict: + state_dict[qkv_key_b] = state_dict.pop(key) + elif part in ('to_k', 'to_v'): + qkv = 
state_dict.get(qkv_key_b) + if qkv is not None: + state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) + + map_from_my_dit = {} + for i in range(28): + map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' + map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' + map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' + map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' + map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' + map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' + map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' + map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' + + map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' + map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' + map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' + map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' + + map_from_my_dit[f'Simplified_FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + + for key in map_from_my_dit.keys(): + state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) From dc0c45caf1249f4f88b82c908e0e3aada5d4e631 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 07:00:51 +0000 Subject: [PATCH 11/24] update transformer_2d & simplified_facebook_dit --- .../models/simplified_facebook_dit.py | 18 ++--- .../ppdiffusers/models/transformer_2d.py | 79 ++++++++----------- 2 files changed, 40 insertions(+), 57 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 14a1982b1..56394d114 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -3,7 +3,7 @@ import paddle.nn.functional as F import math -class Simplified_FacebookDIT(nn.Layer): +class SimplifiedFacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() self.num_layers = num_layers @@ -29,7 +29,9 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.LabelEmbedding_num_hidden_size) for i in range(self.num_layers)]) - self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)]) + self.q = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) + self.k = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) + self.v = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)]) self.ffn1 
= nn.LayerList([nn.Linear(dim, dim*4) for i in range(self.num_layers)]) self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) @@ -63,14 +65,12 @@ def forward(self, hidden_states, timesteps, class_labels): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) import paddlemix norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) - q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) - b,s,h = q.shape - q = q.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) - k = k.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) - v = v.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) + q = self.q[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) + k = self.k[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) + v = self.v[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) - norm_hidden_states = norm_hidden_states.reshape([b,s,self.dim]) + norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.dim]) norm_hidden_states = self.out_proj[i](norm_hidden_states) # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) @@ -81,7 +81,7 @@ def forward(self, hidden_states, timesteps, class_labels): norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) norm_hidden_states = self.ffn2[i](norm_hidden_states) - hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b,1,self.dim]) + hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0],1,self.dim]) return hidden_states \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index b82292b68..ab8bbc3bd 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -28,7 +28,7 @@ recompute_use_reentrant, use_old_recompute, ) -from .simplified_facebook_dit import Simplified_FacebookDIT +from .simplified_facebook_dit import SimplifiedFacebookDIT from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed @@ -220,7 +220,7 @@ def __init__( ] ) if self.Inference_Optimize: - self.Simplified_FacebookDIT = Simplified_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -395,7 +395,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) if self.Inference_Optimize: - hidden_states =self.Simplified_FacebookDIT(hidden_states, timestep, class_labels) + hidden_states =self.simplified_facebookDIT(hidden_states, timestep, class_labels) else: for block in self.transformer_blocks: if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): @@ -497,48 +497,31 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv('Inference_Optimize'): - for key in list(state_dict.keys()): - if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight' - if part == 'to_q' and qkv_key_w not in state_dict: - state_dict[qkv_key_w] = state_dict.pop(key) - elif part in ('to_k', 'to_v'): - qkv = state_dict.get(qkv_key_w) - if qkv is not None: - state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) - if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias' - if part == 'to_q' and qkv_key_b not in state_dict: - state_dict[qkv_key_b] = state_dict.pop(key) - elif part in ('to_k', 'to_v'): - qkv = state_dict.get(qkv_key_b) - if qkv is not None: - state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1) - - map_from_my_dit = {} - for i in range(28): - map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight' - map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias' - map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' - map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' - map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' - map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' - map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' - map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' - - map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' - map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' - map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' - map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' - - map_from_my_dit[f'Simplified_FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' - - for key in map_from_my_dit.keys(): - state_dict[key] = 
paddle.assign(state_dict[map_from_my_dit[key]]) + if not os.getenv('Inference_Optimize'): + return + map_from_my_dit = {} + for i in range(28): + map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight' + map_from_my_dit[f'simplified_facebookDIT.k.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_k.weight' + map_from_my_dit[f'simplified_facebookDIT.v.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_v.weight' + map_from_my_dit[f'simplified_facebookDIT.q.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_q.bias' + map_from_my_dit[f'simplified_facebookDIT.k.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_k.bias' + map_from_my_dit[f'simplified_facebookDIT.v.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_v.bias' + map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' + map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' + map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' + map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' + map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' + map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' + + map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' + map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' + map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' + map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' + map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' + map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' + + map_from_my_dit[f'simplified_facebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + + for key in map_from_my_dit.keys(): + state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) From 42f61bcd20a63564e1158039c1e97c9f649985e1 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 07:42:03 +0000 Subject: [PATCH 12/24] update demo & implified_facebook_dit & transformer_2d --- .../class_conditional_image_generation-dit.py | 34 +------------------ .../models/simplified_facebook_dit.py | 12 +++---- .../ppdiffusers/models/transformer_2d.py | 2 +- 3 files changed, 8 insertions(+), 40 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 6249af008..d8929d6ad 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -13,7 +13,6 @@ # limitations under the License. 
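
The custom_modify_weight rewrite above reduces to one idea: build a state dict whose tensors sit under the key names the flattened per-layer module expects, then load it in one pass. A self-contained toy sketch of that pattern follows; the layer names, sizes and layer count are invented for illustration and are not taken from the repository.

import paddle
from paddle import nn

class Block(nn.Layer):
    # Stand-in for one transformer block that owns a named projection.
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)

class Original(nn.Layer):
    # Per-block layout: parameter keys look like "transformer_blocks.0.to_q.weight".
    def __init__(self):
        super().__init__()
        self.transformer_blocks = nn.LayerList([Block() for _ in range(2)])

class Flattened(nn.Layer):
    # Flattened layout: parameter keys look like "q.0.weight".
    def __init__(self):
        super().__init__()
        self.q = nn.LayerList([nn.Linear(8, 8) for _ in range(2)])

original, flattened = Original(), Flattened()
old_state = original.state_dict()
new_state = {}
for i in range(2):
    for suffix in ("weight", "bias"):
        # Same move as custom_modify_weight: copy each tensor under the new key name.
        new_state[f"q.{i}.{suffix}"] = paddle.assign(old_state[f"transformer_blocks.{i}.to_q.{suffix}"])
flattened.set_state_dict(new_state)
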
import os -os.environ["CUDA_VISIBLE_DEVICES"] = "2" import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline @@ -34,35 +33,4 @@ image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -# image.save("class_conditional_image_generation-dit-result.png") -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - - -import datetime -import time - -warm_up_times = 5 -repeat_times = 10 -sum_time = 0. - -for i in range(repeat_times): - paddle.device.synchronize() - starttime = datetime.datetime.now() - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - paddle.device.synchronize() - endtime = datetime.datetime.now() - duringtime = endtime - starttime - time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 - evet = "every_time: " + str(time_ms) + "ms\n\n" - with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file: - time_file.write(evet) - sum_time+=time_ms -print("The ave end to end time : ", sum_time / repeat_times, "ms") -msg = "average_time: " + str(sum_time / repeat_times) + "ms\n\n" -print(msg) -with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file: - time_file.write(msg) - -image.save("class_conditional_image_generation-dit-29.png") - +image.save("class_conditional_image_generation-dit-result.png") diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 56394d114..0ac7d69ee 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -8,8 +8,8 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio super().__init__() self.num_layers = num_layers self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim + self.heads_num = num_attention_heads + self.head_dim = attention_head_dim self.timestep_embedder_in_channels = 256 self.timestep_embedder_time_embed_dim = 1152 self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim @@ -65,11 +65,11 @@ def forward(self, hidden_states, timesteps, class_labels): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) import paddlemix norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) - q = self.q[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) - k = self.k[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) - v = self.v[i](norm_hidden_states).reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.num_attention_heads,self.attention_head_dim]) + q = self.q[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) + k = self.k[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) + v = self.v[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) - norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) + norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5) norm_hidden_states = 
norm_hidden_states.reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.dim]) norm_hidden_states = self.out_proj[i](norm_hidden_states) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index ab8bbc3bd..d1dc0ff29 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -220,7 +220,7 @@ def __init__( ] ) if self.Inference_Optimize: - self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels From 000dd80ce0eb5afb63e31b51ea26565b351968a5 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 08:00:34 +0000 Subject: [PATCH 13/24] update Inference_Optimize --- .../inference/class_conditional_image_generation-dit.py | 8 ++------ ppdiffusers/ppdiffusers/models/transformer_2d.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index d8929d6ad..554f35ac1 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -17,11 +17,7 @@ from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -Inference_Optimize = True -if Inference_Optimize: - os.environ["Inference_Optimize"] = "True" -else: - pass +os.environ["Inference_Optimize"] = "True" dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) @@ -33,4 +29,4 @@ image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image.save("class_conditional_image_generation-dit-result.png") +image.save("class_conditional_image_generation-dit-result.png") \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index d1dc0ff29..c6793019e 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -118,7 +118,7 @@ def __init__( self.inner_dim = inner_dim = num_attention_heads * attention_head_dim self.data_format = data_format - self.Inference_Optimize = bool(os.getenv('Inference_Optimize')) + self.Inference_Optimize = os.getenv('Inference_Optimize') == "True" conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear @@ -497,7 +497,7 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - if not os.getenv('Inference_Optimize'): + if os.getenv('Inference_Optimize') != "True": return map_from_my_dit = {} for i in range(28): From 9bb9cde142a4394838ec5bad61cc0967c85e042d Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 08:10:56 +0000 Subject: [PATCH 14/24] update demo & simplified_facebook_dit --- .../class_conditional_image_generation-dit.py | 25 ++++++++++++++++++- .../models/simplified_facebook_dit.py | 20 +++++++-------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py 
b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 554f35ac1..7a1c0322c 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -29,4 +29,27 @@ image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image.save("class_conditional_image_generation-dit-result.png") \ No newline at end of file +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + + +import datetime +import time +repeat_times = 10 +sum_time = 0. + +for i in range(repeat_times): + paddle.device.synchronize() + starttime = datetime.datetime.now() + + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + + paddle.device.synchronize() + endtime = datetime.datetime.now() + duringtime = endtime - starttime + + time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + sum_time+=time_ms +print("The ave end to end time : ", sum_time / repeat_times, "ms") +image.save("class_conditional_image_generation-dit_last_result.png") + diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 0ac7d69ee..fab2e52ac 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -17,24 +17,24 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.LabelEmbedding_num_hidden_size = 1152 self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, - self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) + self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, - self.timestep_embedder_time_embed_dim_out) for i in range(self.num_layers)]) + self.timestep_embedder_time_embed_dim_out) for i in range(num_layers)]) self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, - 6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) + 6 * self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes, - self.LabelEmbedding_num_hidden_size) for i in range(self.num_layers)]) + self.LabelEmbedding_num_hidden_size) for i in range(num_layers)]) - self.q = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) - self.k = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) - self.v = nn.LayerList([nn.Linear(dim, dim ) for i in range(self.num_layers)]) - self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)]) - self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(self.num_layers)]) - self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) + self.q = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) + self.k = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) + self.v = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) + self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) + self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(num_layers)]) + self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(num_layers)]) @paddle.incubate.jit.inference(enable_new_ir=True, cache_static_model=False, From 
d3de8382b344f6539b045e2173a6de99065c4d0c Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 08:19:32 +0000 Subject: [PATCH 15/24] update demo --- .../class_conditional_image_generation-dit.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 7a1c0322c..287f44789 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -27,29 +27,25 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) +# warmup +for i in range(5): + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - import datetime import time repeat_times = 10 -sum_time = 0. +paddle.device.synchronize() +starttime = datetime.datetime.now() for i in range(repeat_times): - paddle.device.synchronize() - starttime = datetime.datetime.now() - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - - paddle.device.synchronize() - endtime = datetime.datetime.now() - duringtime = endtime - starttime - - time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 - sum_time+=time_ms -print("The ave end to end time : ", sum_time / repeat_times, "ms") -image.save("class_conditional_image_generation-dit_last_result.png") + +paddle.device.synchronize() +endtime = datetime.datetime.now() +duringtime = endtime - starttime +time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + +print("The ave end to end time : ", time_ms / repeat_times, "ms") +image.save("class_conditional_image_generation-dit-result.png") From 400ab196d17c788ec3b73a4689b1ffbc0aab85e9 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 09:36:01 +0000 Subject: [PATCH 16/24] update demo simplified_facebook_dit transformer_2d --- .../inference/class_conditional_image_generation-dit.py | 2 +- ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py | 5 ----- ppdiffusers/ppdiffusers/models/transformer_2d.py | 6 ++++++ 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 287f44789..2cd401857 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -17,7 +17,7 @@ from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -os.environ["Inference_Optimize"] = "True" +os.environ["Inference_Optimize"] = "False" dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index fab2e52ac..dd418a9f2 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -36,11 +36,6 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in 
range(num_layers)]) self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(num_layers)]) - @paddle.incubate.jit.inference(enable_new_ir=True, - cache_static_model=False, - exp_enable_use_cutlass=True, - delete_pass_lists=["add_norm_fuse_pass"], - ) def forward(self, hidden_states, timesteps, class_labels): # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index c6793019e..49a0c39a7 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -221,6 +221,12 @@ def __init__( ) if self.Inference_Optimize: self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) + self.simplified_facebookDIT = paddle.incubate.jit.inference(self.simplified_facebookDIT, + enable_new_ir=True, + cache_static_model=False, + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], + ) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels From bfe8c413d3b421e914859c2bc4a1021772a7481f Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 11:04:06 +0000 Subject: [PATCH 17/24] update demo transformer_2d & simplified_facebook_dit --- .../class_conditional_image_generation-dit.py | 27 ++------- .../models/simplified_facebook_dit.py | 57 ++++++++++--------- .../ppdiffusers/models/transformer_2d.py | 36 ++++++------ 3 files changed, 55 insertions(+), 65 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 2cd401857..8445f6c6b 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -17,8 +17,8 @@ from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -os.environ["Inference_Optimize"] = "False" - +os.environ["INFOPTIMIZE"] = "False" + dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -27,25 +27,6 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) -# warmup -for i in range(5): - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - - -import datetime -import time -repeat_times = 10 -paddle.device.synchronize() -starttime = datetime.datetime.now() - -for i in range(repeat_times): - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - -paddle.device.synchronize() -endtime = datetime.datetime.now() -duringtime = endtime - starttime -time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 - -print("The ave end to end time : ", time_ms / repeat_times, "ms") -image.save("class_conditional_image_generation-dit-result.png") +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] +image.save("class_conditional_image_generation-dit-result.png") \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index dd418a9f2..6247c8d45 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -3,6 +3,7 @@ import paddle.nn.functional 
as F import math + class SimplifiedFacebookDIT(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() @@ -15,29 +16,28 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim self.LabelEmbedding_num_classes = 1001 self.LabelEmbedding_num_hidden_size = 1152 - - self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, + + self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) - + self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, self.timestep_embedder_time_embed_dim_out) for i in range(num_layers)]) - + self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, 6 * self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) - - self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes, + + self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes, self.LabelEmbedding_num_hidden_size) for i in range(num_layers)]) - - self.q = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) - self.k = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) - self.v = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)]) + self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) + self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) + self.v = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) - self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(num_layers)]) - self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(num_layers)]) + self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)]) + self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)]) def forward(self, hidden_states, timesteps, class_labels): - + # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py num_channels = 256 max_period = 10000 @@ -49,8 +49,8 @@ def forward(self, hidden_states, timesteps, class_labels): emb = timesteps[:, None].cast("float32") * emb[None, :] emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) common_emb = emb.cast(hidden_states.dtype) - - for i in range(self.num_layers): + + for i in range(self.num_layers): emb = self.fcs0[i](common_emb) emb = F.silu(emb) emb = self.fcs1[i](emb) @@ -59,24 +59,29 @@ def forward(self, hidden_states, timesteps, class_labels): emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) import paddlemix - norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) - q = self.q[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) - k = self.k[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) - v = self.v[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim]) + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) + q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) + k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) + v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) norm_hidden_states = F.scaled_dot_product_attention_(q, 
k, v, scale=self.head_dim**-0.5) - norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.dim]) + norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]) norm_hidden_states = self.out_proj[i](norm_hidden_states) - - # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) + + # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) # norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) - hidden_states,norm_hidden_states =paddlemix.triton_ops.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp) + hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + hidden_states, + norm_hidden_states, + gate_msa, + scale_mlp, + shift_mlp + ) norm_hidden_states = self.ffn1[i](norm_hidden_states) norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) norm_hidden_states = self.ffn2[i](norm_hidden_states) - hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0],1,self.dim]) - + hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0], 1, self.dim]) + return hidden_states - \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 49a0c39a7..58c8b064c 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -38,7 +38,6 @@ import os - @dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -118,8 +117,8 @@ def __init__( self.inner_dim = inner_dim = num_attention_heads * attention_head_dim self.data_format = data_format - self.Inference_Optimize = os.getenv('Inference_Optimize') == "True" - + self.inference_optimize = os.getenv('INFOPTIMIZE') == "True" + conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear @@ -219,14 +218,20 @@ def __init__( for d in range(num_layers) ] ) - if self.Inference_Optimize: - self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim) - self.simplified_facebookDIT = paddle.incubate.jit.inference(self.simplified_facebookDIT, - enable_new_ir=True, - cache_static_model=False, - exp_enable_use_cutlass=True, - delete_pass_lists=["add_norm_fuse_pass"], - ) + if self.inference_optimize: + self.simplified_facebookDIT = SimplifiedFacebookDIT( + num_layers, + inner_dim, + num_attention_heads, + attention_head_dim + ) + self.simplified_facebookDIT = paddle.incubate.jit.inference( + self.simplified_facebookDIT, + enable_new_ir=True, + cache_static_model=False, + exp_enable_use_cutlass=True, + delete_pass_lists=["add_norm_fuse_pass"], + ) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -264,7 +269,6 @@ def __init__( self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -399,9 +403,9 @@ def forward( batch_size = hidden_states.shape[0] encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - - if self.Inference_Optimize: - hidden_states =self.simplified_facebookDIT(hidden_states, timestep, class_labels) + + if self.inference_optimize: + hidden_states = self.simplified_facebookDIT(hidden_states, timestep, class_labels) else: for block in self.transformer_blocks: if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): @@ -503,7 +507,7 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv('Inference_Optimize') != "True": + if os.getenv('INFOPTIMIZE') != "True": return map_from_my_dit = {} for i in range(28): From 8896057f63ffeb72cba267afff6bb28e782adae1 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 11:35:18 +0000 Subject: [PATCH 18/24] test --- .../inference/class_conditional_image_generation-dit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 8445f6c6b..f540683a8 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -13,8 +13,10 @@ # limitations under the License. import os + import paddle from paddlenlp.trainer import set_seed + from ppdiffusers import DDIMScheduler, DiTPipeline os.environ["INFOPTIMIZE"] = "False" @@ -27,6 +29,5 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] -image.save("class_conditional_image_generation-dit-result.png") \ No newline at end of file +image.save("class_conditional_image_generation-dit-result.png") From e9aa47dd4410b34686e62ee6eceac8cb9770d6a3 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 11:42:40 +0000 Subject: [PATCH 19/24] add format --- .../models/simplified_facebook_dit.py | 67 ++++++++++---- .../ppdiffusers/models/transformer_2d.py | 89 +++++++++++-------- 2 files changed, 101 insertions(+), 55 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 6247c8d45..dfe64c199 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -1,7 +1,22 @@ -from paddle import nn +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + import paddle import paddle.nn.functional as F -import math +from paddle import nn class SimplifiedFacebookDIT(nn.Layer): @@ -17,17 +32,33 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.LabelEmbedding_num_classes = 1001 self.LabelEmbedding_num_hidden_size = 1152 - self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, - self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) + self.fcs0 = nn.LayerList( + [ + nn.Linear(self.timestep_embedder_in_channels, self.timestep_embedder_time_embed_dim) + for i in range(num_layers) + ] + ) - self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, - self.timestep_embedder_time_embed_dim_out) for i in range(num_layers)]) + self.fcs1 = nn.LayerList( + [ + nn.Linear(self.timestep_embedder_time_embed_dim, self.timestep_embedder_time_embed_dim_out) + for i in range(num_layers) + ] + ) - self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, - 6 * self.timestep_embedder_time_embed_dim) for i in range(num_layers)]) + self.fcs2 = nn.LayerList( + [ + nn.Linear(self.timestep_embedder_time_embed_dim, 6 * self.timestep_embedder_time_embed_dim) + for i in range(num_layers) + ] + ) - self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes, - self.LabelEmbedding_num_hidden_size) for i in range(num_layers)]) + self.embs = nn.LayerList( + [ + nn.Embedding(self.LabelEmbedding_num_classes, self.LabelEmbedding_num_hidden_size) + for i in range(num_layers) + ] + ) self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) @@ -59,29 +90,29 @@ def forward(self, hidden_states, timesteps, class_labels): emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) import paddlemix + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5) - norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]) + norm_hidden_states = norm_hidden_states.reshape( + [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim] + ) norm_hidden_states = self.out_proj[i](norm_hidden_states) # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) # norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - hidden_states, - norm_hidden_states, - gate_msa, - scale_mlp, - shift_mlp + hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp ) - norm_hidden_states = self.ffn1[i](norm_hidden_states) norm_hidden_states = 
F.gelu(norm_hidden_states, approximate=True) norm_hidden_states = self.ffn2[i](norm_hidden_states) - hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0], 1, self.dim]) + hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape( + [norm_hidden_states.shape[0], 1, self.dim] + ) return hidden_states diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 58c8b064c..71d31c648 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from dataclasses import dataclass from typing import Any, Dict, Optional @@ -28,14 +29,12 @@ recompute_use_reentrant, use_old_recompute, ) -from .simplified_facebook_dit import SimplifiedFacebookDIT - from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin from .normalization import AdaLayerNormSingle -import os +from .simplified_facebook_dit import SimplifiedFacebookDIT @dataclass @@ -117,7 +116,7 @@ def __init__( self.inner_dim = inner_dim = num_attention_heads * attention_head_dim self.data_format = data_format - self.inference_optimize = os.getenv('INFOPTIMIZE') == "True" + self.inference_optimize = os.getenv("INFOPTIMIZE") == "True" conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear @@ -159,11 +158,15 @@ def __init__( if self.is_input_continuous: self.in_channels = in_channels - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format) + self.norm = nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format + ) if use_linear_projection: self.proj_in = linear_cls(in_channels, inner_dim) else: - self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format) + self.proj_in = conv_cls( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format + ) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -219,14 +222,11 @@ def __init__( ] ) if self.inference_optimize: - self.simplified_facebookDIT = SimplifiedFacebookDIT( - num_layers, - inner_dim, - num_attention_heads, - attention_head_dim + self.simplified_facebookdit = SimplifiedFacebookDIT( + num_layers, inner_dim, num_attention_heads, attention_head_dim ) - self.simplified_facebookDIT = paddle.incubate.jit.inference( - self.simplified_facebookDIT, + self.simplified_facebookdit = paddle.incubate.jit.inference( + self.simplified_facebookdit, enable_new_ir=True, cache_static_model=False, exp_enable_use_cutlass=True, @@ -240,7 +240,9 @@ def __init__( if use_linear_projection: self.proj_out = linear_cls(inner_dim, in_channels) else: - self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format) + self.proj_out = conv_cls( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format + ) 
elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -405,7 +407,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) if self.inference_optimize: - hidden_states = self.simplified_facebookDIT(hidden_states, timestep, class_labels) + hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels) else: for block in self.transformer_blocks: if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): @@ -507,31 +509,44 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv('INFOPTIMIZE') != "True": + if os.getenv("INFOPTIMIZE") != "True": return map_from_my_dit = {} for i in range(28): - map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight' - map_from_my_dit[f'simplified_facebookDIT.k.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_k.weight' - map_from_my_dit[f'simplified_facebookDIT.v.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_v.weight' - map_from_my_dit[f'simplified_facebookDIT.q.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_q.bias' - map_from_my_dit[f'simplified_facebookDIT.k.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_k.bias' - map_from_my_dit[f'simplified_facebookDIT.v.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_v.bias' - map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight' - map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias' - map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight' - map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias' - map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight' - map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias' - - map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight' - map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias' - map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight' - map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias' - map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight' - map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias' - - map_from_my_dit[f'simplified_facebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight' + map_from_my_dit[f"simplified_facebookdit.q.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_q.weight" + map_from_my_dit[f"simplified_facebookdit.k.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_k.weight" + map_from_my_dit[f"simplified_facebookdit.v.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_v.weight" + map_from_my_dit[f"simplified_facebookdit.q.{i}.bias"] = f"transformer_blocks.{i}.attn1.to_q.bias" + map_from_my_dit[f"simplified_facebookdit.k.{i}.bias"] = f"transformer_blocks.{i}.attn1.to_k.bias" + map_from_my_dit[f"simplified_facebookdit.v.{i}.bias"] = 
f"transformer_blocks.{i}.attn1.to_v.bias" + map_from_my_dit[ + f"simplified_facebookdit.out_proj.{i}.weight" + ] = f"transformer_blocks.{i}.attn1.to_out.0.weight" + map_from_my_dit[ + f"simplified_facebookdit.out_proj.{i}.bias" + ] = f"transformer_blocks.{i}.attn1.to_out.0.bias" + map_from_my_dit[f"simplified_facebookdit.ffn1.{i}.weight"] = f"transformer_blocks.{i}.ff.net.0.proj.weight" + map_from_my_dit[f"simplified_facebookdit.ffn1.{i}.bias"] = f"transformer_blocks.{i}.ff.net.0.proj.bias" + map_from_my_dit[f"simplified_facebookdit.ffn2.{i}.weight"] = f"transformer_blocks.{i}.ff.net.2.weight" + map_from_my_dit[f"simplified_facebookdit.ffn2.{i}.bias"] = f"transformer_blocks.{i}.ff.net.2.bias" + + map_from_my_dit[ + f"simplified_facebookdit.fcs0.{i}.weight" + ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight" + map_from_my_dit[ + f"simplified_facebookdit.fcs0.{i}.bias" + ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias" + map_from_my_dit[ + f"simplified_facebookdit.fcs1.{i}.weight" + ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight" + map_from_my_dit[ + f"simplified_facebookdit.fcs1.{i}.bias" + ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias" + map_from_my_dit[f"simplified_facebookdit.fcs2.{i}.weight"] = f"transformer_blocks.{i}.norm1.linear.weight" + map_from_my_dit[f"simplified_facebookdit.fcs2.{i}.bias"] = f"transformer_blocks.{i}.norm1.linear.bias" + map_from_my_dit[ + f"simplified_facebookdit.embs.{i}.weight" + ] = f"transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight" for key in map_from_my_dit.keys(): state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) From c8916f7059877c042e5e1ba34094987c9612a24c Mon Sep 17 00:00:00 2001 From: changwenbin Date: Wed, 7 Aug 2024 11:52:17 +0000 Subject: [PATCH 20/24] add format --- ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index dfe64c199..2a7104b24 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -102,8 +102,6 @@ def forward(self, hidden_states, timesteps, class_labels): ) norm_hidden_states = self.out_proj[i](norm_hidden_states) - # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) - # norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp ) From a87f81bd1dd4a1ff04f89284cd8bdfa64c536b42 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 8 Aug 2024 11:45:19 +0000 Subject: [PATCH 21/24] add Argument to the demo --- .../class_conditional_image_generation-dit.py | 60 ++++++++++++++++++- .../models/simplified_facebook_dit.py | 35 +++++++++-- .../ppdiffusers/models/transformer_2d.py | 6 +- 3 files changed, 92 insertions(+), 9 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index f540683a8..004d9a851 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -12,6 +12,7 @@ # See the License for the specific 
language governing permissions and # limitations under the License. +import argparse import os import paddle @@ -19,7 +20,46 @@ from ppdiffusers import DDIMScheduler, DiTPipeline -os.environ["INFOPTIMIZE"] = "False" + +def parse_args(): + parser = argparse.ArgumentParser( + description=" Use PaddleMIX to accelerate the Diffusion Transformer image generation model." + ) + parser.add_argument( + "--benchmark", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + help="if benchmark is set to True, measure inference performance", + ), + parser.add_argument( + "--inference_optimize", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + help="If inference_optimize is set to True, all optimizations except Triton are enabled.", + ), + parser.add_argument( + "--inference_optimize_triton_an", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, + help="If inference_optimize_triton_an is set to True, the Triton optimization operator 'adaptive_layer_norm' is enabled.", + ), + parser.add_argument( + "--inference_optimize_triton_asr", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, + help="If inference_optimize_triton_an is set to True, the Triton optimization operator 'fused_adaLN_scale_residual' is enabled.", + ) + return parser.parse_args() + + +args = parse_args() + +if args.inference_optimize: + os.environ["INFERENCE_OPTIMIZE"] = "True" +if args.inference_optimize_triton_an: + os.environ["INFERENCE_OPTIMIZE_TRITON_AN"] = "True" +if args.inference_optimize_triton_asr: + os.environ["INFERENCE_OPTIMIZE_TRITON_ASR"] = "True" dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) @@ -28,6 +68,22 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + +if args.benchmark: + import datetime + + # warmup + for i in range(5): + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + repeat_times = 5 + paddle.device.synchronize() + starttime = datetime.datetime.now() + for i in range(repeat_times): + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + paddle.device.synchronize() + endtime = datetime.datetime.now() + duringtime = endtime - starttime + time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + print("The ave end to end time : ", time_ms / repeat_times, "ms") image.save("class_conditional_image_generation-dit-result.png") diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 2a7104b24..700a13c24 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -13,6 +13,7 @@ # limitations under the License. 
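
The benchmark block added to the demo above brackets the timed loop with paddle.device.synchronize() because Paddle may launch GPU kernels asynchronously, so synchronizing before reading the clock keeps queued work from being missed or double-counted. A compact helper expressing the same measurement pattern is sketched below; the function name is ours, while the warm-up and repeat counts mirror the patch, so treat it as an illustrative sketch rather than code from the repository.

import datetime

import paddle

def average_latency_ms(run_once, warmup=5, repeat=5):
    # Warm up so one-time costs (weight conversion, kernel selection, static
    # graph capture) are not charged to the timed iterations.
    for _ in range(warmup):
        run_once()
    paddle.device.synchronize()  # drain queued GPU work before starting the clock
    start = datetime.datetime.now()
    for _ in range(repeat):
        run_once()
    paddle.device.synchronize()  # make sure the timed work has actually finished
    elapsed = datetime.datetime.now() - start
    return (elapsed.seconds * 1000 + elapsed.microseconds / 1000.0) / repeat

With the pipeline from the demo, average_latency_ms(lambda: pipe(class_labels=class_ids, num_inference_steps=25)) reproduces the same end-to-end number that the inline benchmark prints.
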
import math +import os import paddle import paddle.nn.functional as F @@ -66,6 +67,8 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)]) self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)]) self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)]) + self.norm = nn.LayerNorm(1152, epsilon=1e-06, weight_attr=False, bias_attr=False) + self.norm1 = nn.LayerNorm(1152, epsilon=1e-05, weight_attr=False, bias_attr=False) def forward(self, hidden_states, timesteps, class_labels): @@ -89,9 +92,19 @@ def forward(self, hidden_states, timesteps, class_labels): emb = F.silu(emb) emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - import paddlemix - norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) + if os.getenv("INFERENCE_OPTIMIZE_TRITON_AN"): + import paddlemix + + norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( + hidden_states, scale_msa, shift_msa, epsilon=1e-06 + ) + else: + norm_hidden_states = self.norm( + hidden_states, + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None] + q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim]) @@ -102,9 +115,21 @@ def forward(self, hidden_states, timesteps, class_labels): ) norm_hidden_states = self.out_proj[i](norm_hidden_states) - hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( - hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp - ) + if os.getenv("INFERENCE_OPTIMIZE_TRITON_ASR"): + import paddlemix + + hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( + hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05 + ) + else: + hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape( + [norm_hidden_states.shape[0], 1, self.dim] + ) + norm_hidden_states = self.norm1( + hidden_states, + ) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = self.ffn1[i](norm_hidden_states) norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) norm_hidden_states = self.ffn2[i](norm_hidden_states) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 71d31c648..b78ada3e4 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -116,7 +116,7 @@ def __init__( self.inner_dim = inner_dim = num_attention_heads * attention_head_dim self.data_format = data_format - self.inference_optimize = os.getenv("INFOPTIMIZE") == "True" + self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True" conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear @@ -407,7 +407,9 @@ def forward( encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) if self.inference_optimize: + paddle.device.synchronize() hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels) + paddle.device.synchronize() else: for block in self.transformer_blocks: if 
self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): @@ -509,7 +511,7 @@ def custom_forward(*inputs): @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv("INFOPTIMIZE") != "True": + if os.getenv("INFERENCE_OPTIMIZE") != "True": return map_from_my_dit = {} for i in range(28): From 0a09bf27ad9408f1f664a975448af21de5d0e13e Mon Sep 17 00:00:00 2001 From: changwenbin Date: Thu, 8 Aug 2024 12:16:20 +0000 Subject: [PATCH 22/24] update Argument to the demo --- .../class_conditional_image_generation-dit.py | 20 ++++++------------- .../models/simplified_facebook_dit.py | 5 ++--- .../ppdiffusers/models/transformer_2d.py | 2 -- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 004d9a851..8e2d26c55 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -30,24 +30,18 @@ def parse_args(): type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=False, help="if benchmark is set to True, measure inference performance", - ), + ) parser.add_argument( "--inference_optimize", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=False, help="If inference_optimize is set to True, all optimizations except Triton are enabled.", - ), - parser.add_argument( - "--inference_optimize_triton_an", - type=(lambda x: str(x).lower() in ["true", "1", "yes"]), - default=True, - help="If inference_optimize_triton_an is set to True, the Triton optimization operator 'adaptive_layer_norm' is enabled.", - ), + ) parser.add_argument( - "--inference_optimize_triton_asr", + "--inference_optimize_triton", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True, - help="If inference_optimize_triton_an is set to True, the Triton optimization operator 'fused_adaLN_scale_residual' is enabled.", + help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.", ) return parser.parse_args() @@ -56,10 +50,8 @@ def parse_args(): if args.inference_optimize: os.environ["INFERENCE_OPTIMIZE"] = "True" -if args.inference_optimize_triton_an: - os.environ["INFERENCE_OPTIMIZE_TRITON_AN"] = "True" -if args.inference_optimize_triton_asr: - os.environ["INFERENCE_OPTIMIZE_TRITON_ASR"] = "True" +if args.inference_optimize_triton: + os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 700a13c24..03339d493 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -93,7 +93,7 @@ def forward(self, hidden_states, timesteps, class_labels): emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) - if os.getenv("INFERENCE_OPTIMIZE_TRITON_AN"): + if os.getenv("INFERENCE_OPTIMIZE_TRITON"): import paddlemix norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( @@ -114,8 +114,7 @@ def forward(self, hidden_states, timesteps, class_labels): [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim] ) norm_hidden_states = self.out_proj[i](norm_hidden_states) - - if 
os.getenv("INFERENCE_OPTIMIZE_TRITON_ASR"): + if os.getenv("INFERENCE_OPTIMIZE_TRITON"): import paddlemix hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index b78ada3e4..8dd57b158 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -407,9 +407,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) if self.inference_optimize: - paddle.device.synchronize() hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels) - paddle.device.synchronize() else: for block in self.transformer_blocks: if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): From 10953b5e51782942e9d6881e8693f3b5b1baeb7d Mon Sep 17 00:00:00 2001 From: changwenbin Date: Fri, 9 Aug 2024 05:09:32 +0000 Subject: [PATCH 23/24] update transformer_2d --- .../class_conditional_image_generation-dit.py | 4 ++ .../models/simplified_facebook_dit.py | 5 +- .../ppdiffusers/models/transformer_2d.py | 61 +++++++------------ 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 8e2d26c55..963a2684a 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -68,14 +68,18 @@ def parse_args(): # warmup for i in range(5): image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + repeat_times = 5 + paddle.device.synchronize() starttime = datetime.datetime.now() for i in range(repeat_times): image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] paddle.device.synchronize() endtime = datetime.datetime.now() + duringtime = endtime - starttime time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 print("The ave end to end time : ", time_ms / repeat_times, "ms") + image.save("class_conditional_image_generation-dit-result.png") diff --git a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py index 03339d493..2a1fde485 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py +++ b/ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py @@ -92,10 +92,9 @@ def forward(self, hidden_states, timesteps, class_labels): emb = F.silu(emb) emb = self.fcs2[i](emb) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) + import paddlemix if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm( hidden_states, scale_msa, shift_msa, epsilon=1e-06 ) @@ -115,8 +114,6 @@ def forward(self, hidden_states, timesteps, class_labels): ) norm_hidden_states = self.out_proj[i](norm_hidden_states) if os.getenv("INFERENCE_OPTIMIZE_TRITON"): - import paddlemix - hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual( hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05 ) diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 8dd57b158..dd6183d26 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ 
b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -511,42 +511,27 @@ def custom_forward(*inputs): def custom_modify_weight(cls, state_dict): if os.getenv("INFERENCE_OPTIMIZE") != "True": return - map_from_my_dit = {} for i in range(28): - map_from_my_dit[f"simplified_facebookdit.q.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_q.weight" - map_from_my_dit[f"simplified_facebookdit.k.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_k.weight" - map_from_my_dit[f"simplified_facebookdit.v.{i}.weight"] = f"transformer_blocks.{i}.attn1.to_v.weight" - map_from_my_dit[f"simplified_facebookdit.q.{i}.bias"] = f"transformer_blocks.{i}.attn1.to_q.bias" - map_from_my_dit[f"simplified_facebookdit.k.{i}.bias"] = f"transformer_blocks.{i}.attn1.to_k.bias" - map_from_my_dit[f"simplified_facebookdit.v.{i}.bias"] = f"transformer_blocks.{i}.attn1.to_v.bias" - map_from_my_dit[ - f"simplified_facebookdit.out_proj.{i}.weight" - ] = f"transformer_blocks.{i}.attn1.to_out.0.weight" - map_from_my_dit[ - f"simplified_facebookdit.out_proj.{i}.bias" - ] = f"transformer_blocks.{i}.attn1.to_out.0.bias" - map_from_my_dit[f"simplified_facebookdit.ffn1.{i}.weight"] = f"transformer_blocks.{i}.ff.net.0.proj.weight" - map_from_my_dit[f"simplified_facebookdit.ffn1.{i}.bias"] = f"transformer_blocks.{i}.ff.net.0.proj.bias" - map_from_my_dit[f"simplified_facebookdit.ffn2.{i}.weight"] = f"transformer_blocks.{i}.ff.net.2.weight" - map_from_my_dit[f"simplified_facebookdit.ffn2.{i}.bias"] = f"transformer_blocks.{i}.ff.net.2.bias" - - map_from_my_dit[ - f"simplified_facebookdit.fcs0.{i}.weight" - ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight" - map_from_my_dit[ - f"simplified_facebookdit.fcs0.{i}.bias" - ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias" - map_from_my_dit[ - f"simplified_facebookdit.fcs1.{i}.weight" - ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight" - map_from_my_dit[ - f"simplified_facebookdit.fcs1.{i}.bias" - ] = f"transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias" - map_from_my_dit[f"simplified_facebookdit.fcs2.{i}.weight"] = f"transformer_blocks.{i}.norm1.linear.weight" - map_from_my_dit[f"simplified_facebookdit.fcs2.{i}.bias"] = f"transformer_blocks.{i}.norm1.linear.bias" - map_from_my_dit[ - f"simplified_facebookdit.embs.{i}.weight" - ] = f"transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight" - - for key in map_from_my_dit.keys(): - state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) + map_from_my_dit = [ + (f"q.{i}.weight", f"{i}.attn1.to_q.weight"), + (f"k.{i}.weight", f"{i}.attn1.to_k.weight"), + (f"v.{i}.weight", f"{i}.attn1.to_v.weight"), + (f"q.{i}.bias", f"{i}.attn1.to_q.bias"), + (f"k.{i}.bias", f"{i}.attn1.to_k.bias"), + (f"v.{i}.bias", f"{i}.attn1.to_v.bias"), + (f"out_proj.{i}.weight", f"{i}.attn1.to_out.0.weight"), + (f"out_proj.{i}.bias", f"{i}.attn1.to_out.0.bias"), + (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"), + (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"), + (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"), + (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"), + (f"fcs0.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_1.weight"), + (f"fcs0.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_1.bias"), + (f"fcs1.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_2.weight"), + (f"fcs1.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_2.bias"), + (f"fcs2.{i}.weight", f"{i}.norm1.linear.weight"), + (f"fcs2.{i}.bias", f"{i}.norm1.linear.bias"), + (f"embs.{i}.weight", 
f"{i}.norm1.emb.class_embedder.embedding_table.weight"), + ] + for to_, from_ in map_from_my_dit: + state_dict["simplified_facebookdit." + to_] = paddle.assign(state_dict["transformer_blocks." + from_]) From 922d7d0994afacdbf6e2b2c839f0eff53b86e023 Mon Sep 17 00:00:00 2001 From: changwenbin Date: Mon, 19 Aug 2024 03:38:26 +0000 Subject: [PATCH 24/24] update DIT_demo --- .../inference/class_conditional_image_generation-dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 963a2684a..84f6c7d8f 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import datetime import os import paddle @@ -63,7 +64,6 @@ def parse_args(): image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] if args.benchmark: - import datetime # warmup for i in range(5):