
Commit 489e450

Reduce redundant copies and fix a bug (#699)

Avoid a redundant copy in the get_timestep_embedding method. Fix the issue of the requires_grad_and_without_random method being used without an initialized value.

1 parent ad31600 commit 489e450

File tree

2 files changed: +17 -12 lines

ppdiffusers/ppdiffusers/models/embeddings.py

Lines changed: 16 additions & 11 deletions
@@ -19,7 +19,7 @@
 from paddle import nn
 
 from ..utils import USE_PEFT_BACKEND
-from .activations import get_activation, FP32SiLU
+from .activations import FP32SiLU, get_activation
 from .lora import LoRACompatibleLinear
 
 
@@ -52,12 +52,11 @@ def get_timestep_embedding(
     # scale embeddings
     emb = scale * emb
 
-    # concat sine and cosine embeddings
-    emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
-
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
+    else:
+        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
 
     # zero pad
     if embedding_dim % 2 == 1:
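
The restructured branch above is what the commit title's "reduce redundant copies" refers to: the old code always materialized the [sin, cos] concatenation and, when flip_sin_to_cos was set, built a second tensor by slicing and re-concatenating it. A minimal sketch of the equivalence, using a toy 2x4 input in place of the real half-dimension embedding (the helper name sin_cos_embedding and the shapes are illustrative, not part of ppdiffusers):

import paddle

def sin_cos_embedding(emb, flip_sin_to_cos):
    # New layout: each ordering is produced by a single concat instead of
    # concat([sin, cos]) followed by a slice-and-concat flip.
    if flip_sin_to_cos:
        return paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
    return paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)

# Check against the old concat-then-flip behaviour on arbitrary values.
emb = paddle.rand([2, 4])
half_dim = emb.shape[-1]
old = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
old_flipped = paddle.concat([old[:, half_dim:], old[:, :half_dim]], axis=-1)
assert paddle.allclose(old_flipped, sin_cos_embedding(emb, flip_sin_to_cos=True))
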
@@ -136,7 +135,7 @@ def __init__(
         interpolation_scale=1,
         add_pos_embed=True,
         data_format="NCHW",
-        pos_embed_max_size=None, # For SD3 cropping
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
@@ -147,7 +146,12 @@ def __init__(
         self.data_format = data_format
 
         self.proj = nn.Conv2D(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias, data_format=data_format,
+            in_channels,
+            embed_dim,
+            kernel_size=(patch_size, patch_size),
+            stride=patch_size,
+            bias_attr=bias,
+            data_format=data_format,
         )
         if layer_norm:
             norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
@@ -178,6 +182,7 @@ def __init__(
             self.register_buffer(
                 "pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=persistent
             )
+
     def cropped_pos_embed(self, height, width):
         """Crops positional embeddings for SD3 compatibility."""
         if self.pos_embed_max_size is None:
@@ -215,7 +220,7 @@ def forward(self, latent):
         if self.data_format == "NCHW":
             latent = latent.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC
         else:
-            latent = latent.flatten(1, 2) # BHWC -> BNC
+            latent = latent.flatten(1, 2)  # BHWC -> BNC
         if self.layer_norm:
             latent = self.norm(latent)
 
@@ -521,7 +526,6 @@ def forward(self, image_embeds: paddle.Tensor):
         return image_embeds
 
 
-
 class CombinedTimestepTextProjEmbeddings(nn.Layer):
     def __init__(self, embedding_dim, pooled_projection_dim):
         super().__init__()
@@ -532,14 +536,15 @@ def __init__(self, embedding_dim, pooled_projection_dim):
 
     def forward(self, timestep, pooled_projection):
         timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype))  # (N, D)
 
         pooled_projections = self.text_embedder(pooled_projection)
 
         conditioning = timesteps_emb + pooled_projections
 
         return conditioning
 
+
 class CombinedTimestepLabelEmbeddings(nn.Layer):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
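
The other change in this hunk swaps the torch-style timesteps_proj.to(dtype=...) for Paddle's Tensor.cast. A short sketch of the call, with toy shapes and random values standing in for the real projections (nothing here is the actual module code):

import paddle

timesteps_proj = paddle.rand([2, 256])                      # float32 by default
pooled_projection = paddle.rand([2, 2048]).cast("float16")  # e.g. half-precision text features

# Tensor.cast(dtype=...) is Paddle's native dtype conversion, used here in
# place of the torch-style .to(dtype=...) that the old line relied on.
timesteps_proj = timesteps_proj.cast(dtype=pooled_projection.dtype)
print(timesteps_proj.dtype)  # paddle.float16
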
@@ -906,4 +911,4 @@ def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
-        return hidden_states
+        return hidden_states

ppdiffusers/ppdiffusers/patches/paddle_patch.py

Lines changed: 1 addition & 1 deletion
@@ -429,7 +429,7 @@ def scaled_dot_product_attention_(
     # (2) FLAG_USE_CUTLASS_V2 in yes, y, true, t, 1, use cutlass v2
     use_cutlass_v2 = attn_mask is not None or str2bool(os.getenv("FLAG_USE_CUTLASS_V2", "no"))
     if not use_cutlass_v2:
-        with requires_grad_and_without_random(query, key, value):
+        with requires_grad_and_without_random(query, key, value, stop_gradient=False):
             output = memory_efficient_attention(
                 query,
                 key,
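
The second half of the commit message ("no initialized value") is addressed by passing stop_gradient=False explicitly at this call site. The real context manager lives in paddle_patch.py and also handles random state; the sketch below is a simplified, hypothetical stand-in (not the ppdiffusers implementation) covering only the gradient flag, to show why an explicit value at the call site matters:

import contextlib
import paddle

@contextlib.contextmanager
def requires_grad_sketch(*tensors, stop_gradient=False):
    # Remember each tensor's stop_gradient flag, force the requested value
    # inside the block, then restore. Passing stop_gradient explicitly at the
    # call site (as the patched line now does) avoids depending on a value
    # that was never initialized.
    old_flags = [t.stop_gradient for t in tensors]
    try:
        for t in tensors:
            t.stop_gradient = stop_gradient
        yield
    finally:
        for t, flag in zip(tensors, old_flags):
            t.stop_gradient = flag

# Usage mirroring the patched call site, with toy query/key/value tensors.
q, k, v = (paddle.rand([1, 8, 4, 16]) for _ in range(3))
with requires_grad_sketch(q, k, v, stop_gradient=False):
    assert q.stop_gradient is False  # gradients enabled inside the block
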
