 from paddle import nn

 from ..utils import USE_PEFT_BACKEND
-from .activations import get_activation, FP32SiLU
+from .activations import FP32SiLU, get_activation
 from .lora import LoRACompatibleLinear


@@ -52,12 +52,11 @@ def get_timestep_embedding(
     # scale embeddings
     emb = scale * emb

-    # concat sine and cosine embeddings
-    emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
-
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
+    else:
+        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)

     # zero pad
     if embedding_dim % 2 == 1:
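For context on the hunk above: after this change, `flip_sin_to_cos` selects the sine/cosine concatenation order directly instead of concatenating first and re-slicing the two halves. Below is a minimal standalone sketch of the resulting behaviour; it is simplified (it drops the `downscale_freq_shift` handling of the real `get_timestep_embedding`), and the helper name is made up for illustration.

```python
import math

import paddle


def sinusoidal_embedding_sketch(timesteps, embedding_dim, flip_sin_to_cos=False, scale=1.0, max_period=10000):
    """Illustrative re-implementation of the concat logic shown in the hunk above."""
    half_dim = embedding_dim // 2
    freqs = paddle.exp(-math.log(max_period) * paddle.arange(0, half_dim, dtype="float32") / half_dim)
    emb = scale * timesteps.cast("float32").unsqueeze(1) * freqs.unsqueeze(0)  # (N, half_dim)
    # The order of sine/cosine now depends directly on flip_sin_to_cos,
    # instead of concatenating first and swapping the halves afterwards.
    if flip_sin_to_cos:
        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
    else:
        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
    return emb  # (N, 2 * half_dim)


timesteps = paddle.to_tensor([0, 10, 500])
print(sinusoidal_embedding_sketch(timesteps, embedding_dim=8).shape)  # [3, 8]
```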
@@ -136,7 +135,7 @@ def __init__(
         interpolation_scale=1,
         add_pos_embed=True,
         data_format="NCHW",
-        pos_embed_max_size=None, # For SD3 cropping
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()

@@ -147,7 +146,12 @@ def __init__(
         self.data_format = data_format

         self.proj = nn.Conv2D(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias, data_format=data_format,
+            in_channels,
+            embed_dim,
+            kernel_size=(patch_size, patch_size),
+            stride=patch_size,
+            bias_attr=bias,
+            data_format=data_format,
         )
         if layer_norm:
             norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
@@ -178,6 +182,7 @@ def __init__(
             self.register_buffer(
                 "pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=persistent
             )
+
     def cropped_pos_embed(self, height, width):
         """Crops positional embeddings for SD3 compatibility."""
         if self.pos_embed_max_size is None:
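The `cropped_pos_embed` context above is the SD3 path fed by the new `pos_embed_max_size` argument. As a rough, hedged sketch of the idea only (not the module's actual code; the standalone helper and its arguments are hypothetical, and `height`/`width` are assumed to already be in patch units): the buffer holds a `pos_embed_max_size x pos_embed_max_size` grid of embeddings that gets center-cropped to the current latent's patch grid.

```python
import paddle


def crop_pos_embed_sketch(pos_embed, max_size, height, width):
    # pos_embed: [1, max_size * max_size, D]; height/width: patch-grid size of the current latent.
    top = (max_size - height) // 2
    left = (max_size - width) // 2
    grid = pos_embed.reshape([1, max_size, max_size, -1])          # back to a 2-D grid of embeddings
    cropped = grid[:, top : top + height, left : left + width, :]  # center crop
    return cropped.reshape([1, height * width, -1])                # flatten back to [1, N, D]


pos_embed = paddle.randn([1, 16 * 16, 8])
print(crop_pos_embed_sketch(pos_embed, max_size=16, height=8, width=8).shape)  # [1, 64, 8]
```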
@@ -215,7 +220,7 @@ def forward(self, latent):
         if self.data_format == "NCHW":
             latent = latent.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC
         else:
-            latent = latent.flatten(1, 2) # BHWC -> BNC
+            latent = latent.flatten(1, 2)  # BHWC -> BNC
         if self.layer_norm:
             latent = self.norm(latent)

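For the patch-embedding changes above (the reformatted `nn.Conv2D` call and the NCHW flatten path), here is a small standalone sketch of how the projection turns a latent into patch tokens. The shapes and values are illustrative stand-ins, not the module's actual configuration:

```python
import paddle
from paddle import nn

patch_size, in_channels, embed_dim = 2, 4, 32  # illustrative values only
proj = nn.Conv2D(
    in_channels,
    embed_dim,
    kernel_size=(patch_size, patch_size),
    stride=patch_size,
    data_format="NCHW",
)

latent = paddle.randn([1, in_channels, 8, 8])    # B, C, H, W
tokens = proj(latent)                            # B, embed_dim, H/patch, W/patch
tokens = tokens.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC, as in forward()
print(tokens.shape)                              # [1, 16, 32] -> 16 patch tokens of width 32
```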
@@ -521,7 +526,6 @@ def forward(self, image_embeds: paddle.Tensor):
         return image_embeds


-
 class CombinedTimestepTextProjEmbeddings(nn.Layer):
     def __init__(self, embedding_dim, pooled_projection_dim):
         super().__init__()
@@ -532,14 +536,15 @@ def __init__(self, embedding_dim, pooled_projection_dim):

     def forward(self, timestep, pooled_projection):
         timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype))  # (N, D)

         pooled_projections = self.text_embedder(pooled_projection)

         conditioning = timesteps_emb + pooled_projections

         return conditioning

+
 class CombinedTimestepLabelEmbeddings(nn.Layer):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
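On the `timesteps_proj.to(...)` to `timesteps_proj.cast(...)` change above: `Tensor.cast` is the idiomatic Paddle dtype conversion, doing here what the Torch-style `.to(dtype=...)` did in the original. A trivial check, for illustration:

```python
import paddle

x = paddle.ones([2, 4], dtype="float32")
y = x.cast(dtype="float16")  # keyword form, as used in the hunk above
print(x.dtype, y.dtype)      # paddle.float32 paddle.float16
```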
@@ -906,4 +911,4 @@ def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
-        return hidden_states
+        return hidden_states