-
Notifications
You must be signed in to change notification settings - Fork 218
Re-network the DIT, fix some parameters, and simplify the model networking code #632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
nemonameless
merged 28 commits into
PaddlePaddle:develop
from
chang-wenbin:DIT_PaddleMIX_729
Aug 28, 2024
Merged
Changes from 3 commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
59f23a0
modified the dit
chang-wenbin 5fee64b
add zkk_facebook
chang-wenbin f653a66
update zkk_facebook_dit.py
chang-wenbin 3b29d9d
update transformer_2d
chang-wenbin a88caea
update dit optimize
chang-wenbin 54eeec2
update transformer_2d
chang-wenbin 28a62c0
rename facebook_dit
chang-wenbin 884e29a
merge PR
chang-wenbin 15d08b6
merge from develop
chang-wenbin 7d49c49
Fixed the original dynamic image bug
chang-wenbin b03aa8e
update triton op import paddlemix
chang-wenbin cb86d17
update dit
chang-wenbin dc0c45c
update transformer_2d & simplified_facebook_dit
chang-wenbin 42f61bc
update demo & implified_facebook_dit & transformer_2d
chang-wenbin 000dd80
update Inference_Optimize
chang-wenbin 9bb9cde
update demo & simplified_facebook_dit
chang-wenbin d3de838
update demo
chang-wenbin 400ab19
update demo simplified_facebook_dit transformer_2d
chang-wenbin bfe8c41
update demo transformer_2d & simplified_facebook_dit
chang-wenbin 8896057
test
chang-wenbin e9aa47d
add format
chang-wenbin c8916f7
add format
chang-wenbin a87f81b
add Argument to the demo
chang-wenbin 0a09bf2
update Argument to the demo
chang-wenbin 10e8c1f
Merge remote-tracking branch 'upstream/develop' into DIT_PaddleMIX_729
chang-wenbin 10953b5
update transformer_2d
chang-wenbin 922d7d0
update DIT_demo
chang-wenbin c4f8242
Merge branch 'develop' into DIT_PaddleMIX_729
nemonameless File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from paddle import nn | ||
import paddle | ||
import paddle.nn.functional as F | ||
import math | ||
|
||
class ZKKFacebookDIT(nn.Layer): | ||
def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): | ||
super().__init__() | ||
self.num_layers = num_layers | ||
self.dtype = "float16" | ||
self.dim = dim | ||
self.num_attention_heads = num_attention_heads | ||
self.attention_head_dim = attention_head_dim | ||
self.timestep_embedder_in_channels = 256 | ||
self.timestep_embedder_time_embed_dim = 1152 | ||
self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim | ||
self.CombinedTimestepLabelEmbeddings_num_embeddings = 1001 | ||
self.CombinedTimestepLabelEmbeddings_embedding_dim = 1152 | ||
|
||
self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels, | ||
self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) | ||
|
||
self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, | ||
self.timestep_embedder_time_embed_dim_out) for i in range(self.num_layers)]) | ||
|
||
self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim, | ||
6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)]) | ||
|
||
self.embs = nn.LayerList([nn.Embedding(self.CombinedTimestepLabelEmbeddings_embedding_dim, | ||
self.CombinedTimestepLabelEmbeddings_num_embeddings) for i in range(self.num_layers)]) | ||
|
||
|
||
self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)]) | ||
self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)]) | ||
self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(self.num_layers)]) | ||
self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(self.num_layers)]) | ||
|
||
@paddle.incubate.jit.inference(enable_new_ir=True, | ||
cache_static_model=True, | ||
exp_enable_use_cutlass=True, | ||
delete_pass_lists=["add_norm_fuse_pass"], | ||
) | ||
def forward(self, hidden_states, timesteps, class_labels): | ||
|
||
# below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py | ||
num_channels = 256 | ||
max_period = 10000 | ||
downscale_freq_shift = 1 | ||
half_dim = num_channels // 2 | ||
exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32") | ||
exponent = exponent / (half_dim - downscale_freq_shift) | ||
emb = paddle.exp(exponent) | ||
emb = timesteps[:, None].cast("float32") * emb[None, :] | ||
emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1) | ||
common_emb = emb.cast(self.dtype) | ||
|
||
for i in range(self.num_layers): #$$ for? | ||
emb = self.fcs0[i](common_emb) | ||
emb = F.silu(emb) | ||
emb = self.fcs1[i](emb) | ||
emb = emb + self.embs[i](class_labels) | ||
emb = F.silu(emb) | ||
emb = self.fcs2[i](emb) | ||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1) | ||
norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) | ||
q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) | ||
b,s,h = q.shape | ||
q = q.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) | ||
k = k.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) | ||
v = v.reshape([b,s,self.num_attention_heads,self.attention_head_dim]) | ||
|
||
norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) | ||
norm_hidden_states = norm_hidden_states.reshape([b,s,self.dim]) | ||
norm_hidden_states = self.out_proj[i](norm_hidden_states) | ||
|
||
# hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim]) | ||
# norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) | ||
|
||
hidden_states,norm_hidden_states = paddle.incubate.tt.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp) | ||
|
||
norm_hidden_states = self.ffn1[i](norm_hidden_states) | ||
norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) | ||
norm_hidden_states = self.ffn2[i](norm_hidden_states) | ||
|
||
hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b,1,self.dim]) | ||
|
||
return hidden_states | ||
|
||
|
||
# tmp = paddle.arange(dtype='float32', end=128) | ||
# tmp = tmp * -9.21034049987793 * 0.007874015718698502 | ||
# tmp = paddle.exp(tmp).reshape([1,128]) | ||
# timestep = timestep.cast("float32") | ||
# timestep = timestep.reshape([2,1]) | ||
# tmp = tmp * timestep | ||
# tmp = paddle.concat([paddle.cos(tmp), paddle.sin(tmp)], axis=-1) | ||
# common_tmp = tmp.cast(self.dtype) | ||
# for i in range(self.num_layers): | ||
# tmp = self.fcs0[i](common_tmp) | ||
# tmp = F.silu(tmp) | ||
# tmp = self.fcs1[i](tmp) | ||
# tmp = tmp + self.embs[i](class_labels) | ||
# tmp = F.silu(tmp) | ||
# tmp = self.fcs2[i](tmp) | ||
# shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tmp.chunk(6, axis=1) | ||
# norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_msa, shift_msa) | ||
# q,k,v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1) | ||
# orm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5) | ||
# norm_hidden_states = norm_hidden_states.reshape([2,256,1152]) | ||
# orm_hidden_states = self.out_proj[i](norm_hidden_states) | ||
# hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([2,1,1152]) | ||
# norm_hidden_states = paddle.incubate.tt.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp) | ||
# norm_hidden_states = self.ffn1[i](norm_hidden_states) | ||
# norm_hidden_states = F.gelu(norm_hidden_states, approximate=True) | ||
# norm_hidden_states = self.ffn2[i](norm_hidden_states) | ||
# hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([2,1,1152]) | ||
# return hidden_states | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
518行以下改成
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更改!
感谢提供修改意见,辛苦!