Re-network the DIT, fix some parameters, and simplify the model networking code #632
@@ -12,19 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
os.environ["Inference_Optimize"] = "False"

dtype = paddle.float16
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

words = ["golden retriever"]  # class_ids [207]
class_ids = pipe.get_label_ids(words)

# warmup
for i in range(5):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Reviewer: This warmup is only here to measure the benchmark; regular users don't need it. Please check whether a benchmark switch should be added.
Author: Changed, added switches for benchmark & inference_optimize!
import datetime
import time
Reviewer: Move these imports to the top of the file.
Author: Changed!
repeat_times = 10
paddle.device.synchronize()
starttime = datetime.datetime.now()

image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
for i in range(repeat_times):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Reviewer: Same as above, this is only needed for benchmarking, not for normal use.
Author: Changed!
paddle.device.synchronize()
endtime = datetime.datetime.now()
duringtime = endtime - starttime
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0

print("The ave end to end time : ", time_ms / repeat_times, "ms")
image.save("class_conditional_image_generation-dit-result.png")
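The review above asks for the warmup and timing loops to be gated behind explicit switches. The follow-up commits that add them are not shown in this snapshot, so the snippet below is only a sketch of one way the demo could expose them; the flag names --benchmark and --inference_optimize and the exact structure are assumptions, not the code that actually landed. It keeps the Inference_Optimize environment-variable spelling used in this diff (a later comment asks for it to be uppercased).

import argparse
import datetime
import os

import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", action="store_true", help="run warmup and timing loops")
parser.add_argument("--inference_optimize", action="store_true", help="use the simplified DIT fast path")
args = parser.parse_args()

# Transformer2DModel reads this flag in __init__, so it must be set before from_pretrained.
os.environ["Inference_Optimize"] = "True" if args.inference_optimize else "False"

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=paddle.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

class_ids = pipe.get_label_ids(["golden retriever"])  # class id 207

if args.benchmark:
    for _ in range(5):  # warmup is only meaningful when timing
        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
    repeat_times = 10
    paddle.device.synchronize()
    start = datetime.datetime.now()
    for _ in range(repeat_times):
        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
    paddle.device.synchronize()
    elapsed = datetime.datetime.now() - start
    time_ms = elapsed.seconds * 1000 + elapsed.microseconds / 1000.0
    print("The ave end to end time : ", time_ms / repeat_times, "ms")
else:
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

image.save("class_conditional_image_generation-dit-result.png")

Setting the environment variable before building the pipeline matters because Transformer2DModel checks it in __init__ (see the third file below).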
@@ -0,0 +1,82 @@
from paddle import nn
import paddle
import paddle.nn.functional as F
import math


class SimplifiedFacebookDIT(nn.Layer):
Reviewer: Is it really necessary to simplify this module?
Author: The manual optimization needs it: the original dynamic-graph model network is restructured into a leaner, high-performance form, and this module also hoists the redundant computation out of the transformer-block loop, cutting part of the compute.
    def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
        self.num_layers = num_layers
        self.dim = dim
        self.heads_num = num_attention_heads
        self.head_dim = attention_head_dim
        self.timestep_embedder_in_channels = 256
        self.timestep_embedder_time_embed_dim = 1152
        self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
        self.LabelEmbedding_num_classes = 1001
        self.LabelEmbedding_num_hidden_size = 1152

        self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels,
                                            self.timestep_embedder_time_embed_dim) for i in range(num_layers)])
        self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            self.timestep_embedder_time_embed_dim_out) for i in range(num_layers)])
        self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            6 * self.timestep_embedder_time_embed_dim) for i in range(num_layers)])
        self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes,
                                               self.LabelEmbedding_num_hidden_size) for i in range(num_layers)])

        self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.v = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)])
        self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)])

    def forward(self, hidden_states, timesteps, class_labels):
        # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
        num_channels = 256
        max_period = 10000
        downscale_freq_shift = 1
        half_dim = num_channels // 2
        exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
        exponent = exponent / (half_dim - downscale_freq_shift)
        emb = paddle.exp(exponent)
        emb = timesteps[:, None].cast("float32") * emb[None, :]
        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
        common_emb = emb.cast(hidden_states.dtype)

        for i in range(self.num_layers):
            emb = self.fcs0[i](common_emb)
            emb = F.silu(emb)
            emb = self.fcs1[i](emb)
            emb = emb + self.embs[i](class_labels)
            emb = F.silu(emb)
            emb = self.fcs2[i](emb)
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
            import paddlemix
            norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
            q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
            k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
            v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])

            norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
            norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim])
            norm_hidden_states = self.out_proj[i](norm_hidden_states)

            # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b, 1, self.dim])
            # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp)
            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp)

            norm_hidden_states = self.ffn1[i](norm_hidden_states)
            norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
            norm_hidden_states = self.ffn2[i](norm_hidden_states)

            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0], 1, self.dim])

        return hidden_states
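As a quick orientation for readers (not part of the diff), here is a minimal sketch of how this layer is driven. The constructor arguments match DiT-XL/2: 28 blocks with an inner dim of 1152 (16 heads of size 72), which is consistent with the range(28) loop in custom_modify_weight further down. The import path, the dummy tensor shapes, and the requirement of a CUDA build of Paddle with the paddlemix triton ops installed are my assumptions.

import paddle

# assumed module path, inferred from "from .simplified_facebook_dit import SimplifiedFacebookDIT" below
from ppdiffusers.models.simplified_facebook_dit import SimplifiedFacebookDIT

model = SimplifiedFacebookDIT(num_layers=28, dim=1152,
                              num_attention_heads=16, attention_head_dim=72)
model.to(dtype=paddle.float16)

batch, tokens = 2, 256  # 256 = 16x16 patch tokens for a 32x32 latent with patch size 2 (assumed)
hidden_states = paddle.randn([batch, tokens, 1152], dtype=paddle.float16)
timesteps = paddle.to_tensor([999, 999], dtype=paddle.int64)
class_labels = paddle.to_tensor([207, 207], dtype=paddle.int64)  # "golden retriever"

out = model(hidden_states, timesteps, class_labels)
print(out.shape)  # [2, 256, 1152], same shape as the input hidden_states

In the real pipeline the parameters are copied over from the original transformer_blocks by custom_modify_weight (third file below); here they are randomly initialized, so the call only exercises shapes and the triton kernels.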
@@ -28,11 +28,15 @@
    recompute_use_reentrant,
    use_old_recompute,
)
from .simplified_facebook_dit import SimplifiedFacebookDIT

from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
import os


@dataclass

@@ -114,6 +118,8 @@ def __init__(
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
        self.data_format = data_format

        self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
Reviewer: Use self.inference_optimize to follow the naming convention.
Author: Changed!
        conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

@@ -213,6 +219,14 @@ def __init__(
                for d in range(num_layers)
            ]
        )
        if self.Inference_Optimize:
            self.simplified_facebookDIT = SimplifiedFacebookDIT(
                num_layers, inner_dim, num_attention_heads, attention_head_dim
            )
            self.simplified_facebookDIT = paddle.incubate.jit.inference(
                self.simplified_facebookDIT,
                enable_new_ir=True,
                cache_static_model=False,
                exp_enable_use_cutlass=True,
                delete_pass_lists=["add_norm_fuse_pass"],
Reviewer: Follow the code style guide; a line should not exceed 80 characters.
Author: Adjusted with pre-commit!
            )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels

@@ -250,6 +264,7 @@ def __init__(
            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):

@@ -384,41 +399,44 @@ def forward(
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])

        for block in self.transformer_blocks:
            if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                hidden_states = recompute(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        if self.Inference_Optimize:
            # fast path: run the whole block stack through the simplified, jit.inference-wrapped module
            hidden_states = self.simplified_facebookDIT(hidden_states, timestep, class_labels)
        else:
            for block in self.transformer_blocks:
                if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                    hidden_states = recompute(
                        create_custom_forward(block),
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        timestep,
                        cross_attention_kwargs,
                        class_labels,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        timestep=timestep,
                        cross_attention_kwargs=cross_attention_kwargs,
                        class_labels=class_labels,
                    )

        # 3. Output
        if self.is_input_continuous:

@@ -482,3 +500,34 @@ def custom_forward(*inputs):
            return (output,)

        return Transformer2DModelOutput(sample=output)

    @classmethod
    def custom_modify_weight(cls, state_dict):
        if os.getenv('Inference_Optimize') != "True":
            return
        map_from_my_dit = {}
        for i in range(28):
            map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight'
Reviewer: Try to reduce code duplication; for example, the shared name prefixes should be factored out so that later changes only need to touch one place.
Author: Changed, part of the naming code has been folded!
            map_from_my_dit[f'simplified_facebookDIT.k.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_k.weight'
            map_from_my_dit[f'simplified_facebookDIT.v.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_v.weight'
            map_from_my_dit[f'simplified_facebookDIT.q.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_q.bias'
            map_from_my_dit[f'simplified_facebookDIT.k.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_k.bias'
            map_from_my_dit[f'simplified_facebookDIT.v.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_v.bias'
            map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
            map_from_my_dit[f'simplified_facebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
            map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
            map_from_my_dit[f'simplified_facebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
            map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
            map_from_my_dit[f'simplified_facebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'

            map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
            map_from_my_dit[f'simplified_facebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'
            map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
            map_from_my_dit[f'simplified_facebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'
            map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
            map_from_my_dit[f'simplified_facebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'

            map_from_my_dit[f'simplified_facebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'

        for key in map_from_my_dit.keys():
            state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])
Reviewer: Make the environment variables all uppercase.
Author: Changed!
Thanks for the review comments!
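To make the de-duplication the reviewer asked for concrete: the folded version is not part of this commit range, so the sketch below only illustrates one way to factor out the shared prefixes, by keeping a single table of (simplified attribute, original submodule) name pairs and building the per-layer keys in the loop. It assumes the same surrounding class and imports as the method above, keeps the Inference_Optimize spelling shown in this diff, and is not the code that actually landed.

    @classmethod
    def custom_modify_weight(cls, state_dict):
        if os.getenv('Inference_Optimize') != "True":
            return
        # (SimplifiedFacebookDIT attribute, transformer_blocks submodule) pairs;
        # every entry has both a weight and a bias, so each mapping is written once.
        linear_pairs = [
            ("q", "attn1.to_q"), ("k", "attn1.to_k"), ("v", "attn1.to_v"),
            ("out_proj", "attn1.to_out.0"),
            ("ffn1", "ff.net.0.proj"), ("ffn2", "ff.net.2"),
            ("fcs0", "norm1.emb.timestep_embedder.linear_1"),
            ("fcs1", "norm1.emb.timestep_embedder.linear_2"),
            ("fcs2", "norm1.linear"),
        ]
        for i in range(28):
            src = "simplified_facebookDIT"
            dst = f"transformer_blocks.{i}"
            for mine, theirs in linear_pairs:
                for kind in ("weight", "bias"):
                    state_dict[f"{src}.{mine}.{i}.{kind}"] = paddle.assign(
                        state_dict[f"{dst}.{theirs}.{kind}"]
                    )
            # the class-label embedding only has a weight table, no bias
            state_dict[f"{src}.embs.{i}.weight"] = paddle.assign(
                state_dict[f"{dst}.norm1.emb.class_embedder.embedding_table.weight"]
            )

This copies exactly the same weights as the expanded version in the diff, so the simplified fast path is initialized from the original transformer_blocks parameters either way.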