
Re-network the DIT, fix some parameters, and simplify the model networking code #632


Merged
merged 28 commits on Aug 28, 2024

Changes shown are from 3 of the 28 commits

Commits
59f23a0
modified the dit
chang-wenbin Jul 29, 2024
5fee64b
add zkk_facebook
chang-wenbin Jul 29, 2024
f653a66
update zkk_facebook_dit.py
chang-wenbin Jul 29, 2024
3b29d9d
update transformer_2d
chang-wenbin Jul 30, 2024
a88caea
update dit optimize
chang-wenbin Jul 31, 2024
54eeec2
update transformer_2d
chang-wenbin Jul 31, 2024
28a62c0
rename facebook_dit
chang-wenbin Aug 1, 2024
884e29a
merge PR
chang-wenbin Aug 5, 2024
15d08b6
merge from develop
chang-wenbin Aug 5, 2024
7d49c49
Fixed the original dynamic image bug
chang-wenbin Aug 5, 2024
b03aa8e
update triton op import paddlemix
chang-wenbin Aug 5, 2024
cb86d17
update dit
chang-wenbin Aug 7, 2024
dc0c45c
update transformer_2d & simplified_facebook_dit
chang-wenbin Aug 7, 2024
42f61bc
update demo & implified_facebook_dit & transformer_2d
chang-wenbin Aug 7, 2024
000dd80
update Inference_Optimize
chang-wenbin Aug 7, 2024
9bb9cde
update demo & simplified_facebook_dit
chang-wenbin Aug 7, 2024
d3de838
update demo
chang-wenbin Aug 7, 2024
400ab19
update demo simplified_facebook_dit transformer_2d
chang-wenbin Aug 7, 2024
bfe8c41
update demo transformer_2d & simplified_facebook_dit
chang-wenbin Aug 7, 2024
8896057
test
chang-wenbin Aug 7, 2024
e9aa47d
add format
chang-wenbin Aug 7, 2024
c8916f7
add format
chang-wenbin Aug 7, 2024
a87f81b
add Argument to the demo
chang-wenbin Aug 8, 2024
0a09bf2
update Argument to the demo
chang-wenbin Aug 8, 2024
10e8c1f
Merge remote-tracking branch 'upstream/develop' into DIT_PaddleMIX_729
chang-wenbin Aug 8, 2024
10953b5
update transformer_2d
chang-wenbin Aug 9, 2024
922d7d0
update DIT_demo
chang-wenbin Aug 19, 2024
c4f8242
Merge branch 'develop' into DIT_PaddleMIX_729
nemonameless Aug 28, 2024
5 changes: 5 additions & 0 deletions ppdiffusers/ppdiffusers/models/modeling_utils.py
@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return model

@classmethod
def custom_modify_weight(cls, state_dict):
pass

@classmethod
def _load_pretrained_model(
cls,
@@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
error_msgs.append(
f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
cls.custom_modify_weight(state_dict)
faster_set_state_dict(model_to_load, state_dict)

missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
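
For context, a minimal sketch (not part of this diff) of how a model class could override the new custom_modify_weight hook to remap checkpoint keys in place before faster_set_state_dict applies them. The subclass name and key prefixes below are hypothetical, and the base class is assumed to be the ModelMixin defined in modeling_utils.py:

    # Hypothetical example: rename legacy checkpoint keys before they are loaded.
    class MyDiTModel(ModelMixin):
        @classmethod
        def custom_modify_weight(cls, state_dict):
            # The hook runs right before faster_set_state_dict, so edits made
            # here change what gets copied into the model parameters.
            for key in list(state_dict.keys()):
                if key.startswith("old_blocks."):
                    new_key = "transformer_blocks." + key[len("old_blocks."):]
                    state_dict[new_key] = state_dict.pop(key)
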
164 changes: 128 additions & 36 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
@@ -28,6 +28,8 @@
recompute_use_reentrant,
use_old_recompute,
)
from .zkk_facebook_dit import ZKKFacebookDIT

from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -213,6 +215,8 @@ def __init__(
for d in range(num_layers)
]
)

self.tmp_ZKKFacebookDIT = ZKKFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
@@ -385,41 +389,44 @@ def forward(
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])

for block in self.transformer_blocks:
if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
hidden_states = recompute(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)

# for block in self.transformer_blocks:
# if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

# def create_custom_forward(module, return_dict=None):
# def custom_forward(*inputs):
# if return_dict is not None:
# return module(*inputs, return_dict=return_dict)
# else:
# return module(*inputs)

# return custom_forward

# ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
# hidden_states = recompute(
# create_custom_forward(block),
# hidden_states,
# attention_mask,
# encoder_hidden_states,
# encoder_attention_mask,
# timestep,
# cross_attention_kwargs,
# class_labels,
# **ckpt_kwargs,
# )
# else:
# hidden_states = block(
# hidden_states,
# attention_mask=attention_mask,
# encoder_hidden_states=encoder_hidden_states,
# encoder_attention_mask=encoder_attention_mask,
# timestep=timestep,
# cross_attention_kwargs=cross_attention_kwargs,
# class_labels=class_labels,
# )


hidden_states = self.tmp_ZKKFacebookDIT(hidden_states, timestep, class_labels)

# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
@@ -473,7 +480,8 @@ def custom_forward(*inputs):
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
# hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
hidden_states = hidden_states.transpose([0, 5, 1, 3, 2, 4])  # same permutation as einsum("nhwpqc->nchpwq")
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
@@ -482,3 +490,87 @@ def custom_forward(*inputs):
return (output,)

return Transformer2DModelOutput(sample=output)
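
As a quick sanity check that the einsum-to-transpose change above is a pure reindexing, the following snippet (not part of the diff) compares the two on a random tensor:

    # nhwpqc -> nchpwq: dims (n, h, w, p, q, c) reordered to (n, c, h, p, w, q).
    import paddle
    x = paddle.randn([2, 4, 4, 2, 2, 8])  # n, h, w, p, q, c
    a = paddle.einsum("nhwpqc->nchpwq", x)
    b = x.transpose([0, 5, 1, 3, 2, 4])
    assert paddle.allclose(a, b)
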

    @classmethod
    def custom_modify_weight(cls, state_dict):
        for key in list(state_dict.keys()):
            if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key:
                part = key.split('.')[-2]
                layer_id = key.split('.')[1]
                qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight'
                if part == 'to_q' and qkv_key_w not in state_dict:
                    state_dict[qkv_key_w] = state_dict.pop(key)
                elif part in ('to_k', 'to_v'):
                    qkv = state_dict.get(qkv_key_w)
                    if qkv is not None:
                        state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)
            if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key:
                part = key.split('.')[-2]
                layer_id = key.split('.')[1]
                qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias'
                if part == 'to_q' and qkv_key_b not in state_dict:
                    state_dict[qkv_key_b] = state_dict.pop(key)
                elif part in ('to_k', 'to_v'):
                    qkv = state_dict.get(qkv_key_b)
                    if qkv is not None:
                        state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)

        for key in list(state_dict.keys()):

Contributor commented on this block:

Please change the code below line 518 to:

        map_from_my_dit = {}
        for i in range(28):
            map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'

        for key in map_from_my_dit.keys():
            state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])

Contributor Author replied:

Changed! Thanks for the review suggestions, much appreciated!

            name = ""
            if 'attn1.to_qkv.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.weight'
            if 'attn1.to_qkv.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.qkv.{layer_id}.bias'

            if 'attn1.to_out.0.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.weight'
            if 'attn1.to_out.0.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.out_proj.{layer_id}.bias'

            if 'ff.net.0.proj.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.weight'
            if 'ff.net.0.proj.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.ffn1.{layer_id}.bias'

            if 'ff.net.2.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.weight'
            if 'ff.net.2.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.ffn2.{layer_id}.bias'

            if 'norm1.emb.timestep_embedder.linear_1.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.weight'
            if 'norm1.emb.timestep_embedder.linear_1.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs0.{layer_id}.bias'

            if 'norm1.emb.timestep_embedder.linear_2.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.weight'
            if 'norm1.emb.timestep_embedder.linear_2.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs1.{layer_id}.bias'

            if 'norm1.linear.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.weight'
            if 'norm1.linear.bias' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.fcs2.{layer_id}.bias'

            if 'class_embedder.embedding_table.weight' in key:
                layer_id = int(key.split(".")[1])
                name = f'tmp_ZKKFacebookDIT.embs.{layer_id}.weight'

            # Only copy when the key mapped to a fused-module parameter.
            if name:
                state_dict[name] = paddle.assign(state_dict[key])
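
The remapping above relies on the fact that three separate q/k/v projections can be fused into a single Linear by concatenating their weights along the output axis. A standalone sketch of that equivalence in plain Paddle (toy sizes, not part of the diff):

    # Toy check of the weight-fusion assumption used by custom_modify_weight.
    import paddle
    from paddle import nn

    dim = 8
    to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
    to_qkv = nn.Linear(dim, 3 * dim)

    # Paddle Linear weights are stored as [in_features, out_features],
    # so the output axis is the last one.
    to_qkv.weight.set_value(paddle.concat([to_q.weight, to_k.weight, to_v.weight], axis=-1))
    to_qkv.bias.set_value(paddle.concat([to_q.bias, to_k.bias, to_v.bias], axis=-1))

    x = paddle.randn([2, 4, dim])
    q, k, v = to_qkv(x).chunk(3, axis=-1)
    assert paddle.allclose(q, to_q(x), atol=1e-5)
    assert paddle.allclose(v, to_v(x), atol=1e-5)
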
118 changes: 118 additions & 0 deletions ppdiffusers/ppdiffusers/models/zkk_facebook_dit.py
@@ -0,0 +1,118 @@
from paddle import nn
import paddle
import paddle.nn.functional as F
import math

class ZKKFacebookDIT(nn.Layer):
    def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
        self.num_layers = num_layers
        self.dtype = "float16"
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.timestep_embedder_in_channels = 256
        self.timestep_embedder_time_embed_dim = 1152
        self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
        self.CombinedTimestepLabelEmbeddings_num_embeddings = 1001
        self.CombinedTimestepLabelEmbeddings_embedding_dim = 1152

        self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels,
                                            self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)])

        self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            self.timestep_embedder_time_embed_dim_out) for i in range(self.num_layers)])

        self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)])

        # nn.Embedding expects (num_embeddings, embedding_dim): 1001 class ids mapped to 1152-dim vectors.
        self.embs = nn.LayerList([nn.Embedding(self.CombinedTimestepLabelEmbeddings_num_embeddings,
                                               self.CombinedTimestepLabelEmbeddings_embedding_dim) for i in range(self.num_layers)])

        self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)])
        self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)])
        self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(self.num_layers)])
        self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(self.num_layers)])

    @paddle.incubate.jit.inference(enable_new_ir=True,
                                   cache_static_model=True,
                                   exp_enable_use_cutlass=True,
                                   delete_pass_lists=["add_norm_fuse_pass"],
                                   )
    def forward(self, hidden_states, timesteps, class_labels):

        # The code below is copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
        num_channels = 256
        max_period = 10000
        downscale_freq_shift = 1
        half_dim = num_channels // 2
        exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
        exponent = exponent / (half_dim - downscale_freq_shift)
        emb = paddle.exp(exponent)
        emb = timesteps[:, None].cast("float32") * emb[None, :]
        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
        common_emb = emb.cast(self.dtype)

        for i in range(self.num_layers):
            emb = self.fcs0[i](common_emb)
            emb = F.silu(emb)
            emb = self.fcs1[i](emb)
            emb = emb + self.embs[i](class_labels)
            emb = F.silu(emb)
            emb = self.fcs2[i](emb)
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
            norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
            q, k, v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1)
            b, s, h = q.shape
            q = q.reshape([b, s, self.num_attention_heads, self.attention_head_dim])
            k = k.reshape([b, s, self.num_attention_heads, self.attention_head_dim])
            v = v.reshape([b, s, self.num_attention_heads, self.attention_head_dim])

            norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5)
            norm_hidden_states = norm_hidden_states.reshape([b, s, self.dim])
            norm_hidden_states = self.out_proj[i](norm_hidden_states)

            # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b, 1, self.dim])
            # norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp)

            hidden_states, norm_hidden_states = paddle.incubate.tt.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp)

            norm_hidden_states = self.ffn1[i](norm_hidden_states)
            norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
            norm_hidden_states = self.ffn2[i](norm_hidden_states)

            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b, 1, self.dim])

        return hidden_states


# tmp = paddle.arange(dtype='float32', end=128)
# tmp = tmp * -9.21034049987793 * 0.007874015718698502
# tmp = paddle.exp(tmp).reshape([1,128])
# timestep = timestep.cast("float32")
# timestep = timestep.reshape([2,1])
# tmp = tmp * timestep
# tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1)
# common_tmp = tmp.cast(self.dtype)
# for i in range(self.num_layers):
# tmp = self.fcs0[i](common_tmp)
# tmp = F.silu(tmp)
# tmp = self.fcs1[i](tmp)
# tmp = tmp + self.embs[i](class_labels)
# tmp = F.silu(tmp)
# tmp = self.fcs2[i](tmp)
# shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1)
# norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
# q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1)
# orm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5)
# norm_hidden_states = norm_hidden_states.reshape([2,256,1152])
# orm_hidden_states = self.out_proj[i](norm_hidden_states)
# hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([2,1,1152])
# norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp)
# norm_hidden_states = self.ffn1[i](norm_hidden_states)
# norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
# norm_hidden_states = self.ffn2[i](norm_hidden_states)
# hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152])
# return hidden_states
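
For readers who want the math behind the two paddle.incubate.tt kernels used in the loop above, here is a plain-Paddle reference sketch. It assumes adaptive_layer_norm is a LayerNorm without learned affine followed by a per-sample scale and shift, and that fused_adaLN_scale_residual fuses the gated residual add with the next adaptive layer norm, matching the commented-out two-line version inside the loop; the exact kernel semantics are an assumption, not taken from this PR:

    # Reference (assumed) semantics of the fused kernels, in plain Paddle.
    import paddle
    import paddle.nn.functional as F

    def adaptive_layer_norm_ref(x, scale, shift, epsilon=1e-6):
        # LayerNorm over the channel dim, then DiT-style x * (1 + scale) + shift.
        x = F.layer_norm(x, x.shape[-1:], epsilon=epsilon)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def fused_adaLN_scale_residual_ref(residual, x, gate, scale, shift):
        # Gated residual add for the attention branch, then the adaLN for the MLP branch.
        residual = residual + x * gate.unsqueeze(1)
        return residual, adaptive_layer_norm_ref(residual, scale, shift)

If the fused kernels behave as assumed, swapping these reference functions in for the incubate.tt calls should be a reasonable way to debug the path on hardware where the custom ops are unavailable, up to float16 rounding differences.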