@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import Callable, List, Optional, Union

 import paddle
@@ -28,6 +29,7 @@
 )
 from ppdiffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DiffusionPipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -84,7 +86,7 @@ def __init__(
         clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
@@ -99,7 +101,12 @@ def __init__(
         )

         self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
+        self.cut_out_size = (
+            feature_extractor.size
+            if isinstance(feature_extractor.size, int)
+            else feature_extractor.size["shortest_edge"]
+        )
+        self.make_cutouts = MakeCutouts(self.cut_out_size)

         set_stop_gradient(self.text_encoder, True)
         set_stop_gradient(self.clip_model, True)
@@ -152,7 +159,7 @@ def cond_fn(
         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

-        if isinstance(self.scheduler, PNDMScheduler):
+        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
@@ -174,7 +181,7 @@ def cond_fn(
         if use_cutouts:
             image = self.make_cutouts(image, num_cutouts)
         else:
-            resize_transform = transforms.Resize(self.feature_extractor.size)
+            resize_transform = transforms.Resize(self.cut_out_size)
             image = paddle.stack([resize_transform(img) for img in image], axis=0)
         image = self.normalize(image).astype(latents.dtype)

@@ -207,11 +214,12 @@ def __call__(
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
         use_cutouts: Optional[bool] = True,
-        seed: Optional[int] = None,
+        generator: Optional[paddle.Generator] = None,
         latents: Optional[paddle.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -277,9 +285,9 @@ def __call__(
             text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids)
             text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, axis=-1, keepdim=True)
             # duplicate text embeddings clip for each generation per prompt
-            bs_embed, seq_len, _ = text_embeddings.shape
-            text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt, 1])
-            text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
+            bs_embed, _ = text_embeddings_clip.shape
+            text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt])
+            text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, -1])

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -334,8 +342,7 @@ def __call__(
         # However this currently doesn't work in `mps`.
         latents_shape = [batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8]
         if latents is None:
-            paddle.seed(seed)
-            latents = paddle.randn(latents_shape, dtype=text_embeddings.dtype)
+            latents = paddle.randn(latents_shape, generator=generator, dtype=text_embeddings.dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -350,6 +357,20 @@ def __call__(
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
@@ -381,7 +402,7 @@ def __call__(
             )

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
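
The `extra_step_kwargs` hunk above relies on signature introspection: a keyword argument is forwarded to `scheduler.step()` only when that scheduler's `step()` actually declares it. Here is a standalone sketch of that pattern; `FakeDDIM` and `FakeLMS` are hypothetical stand-ins used purely for illustration, not ppdiffusers classes.

```python
import inspect

class FakeDDIM:
    # stand-in for a scheduler whose step() accepts eta and generator
    def step(self, noise_pred, t, latents, eta=0.0, generator=None):
        return f"ddim step, eta={eta}"

class FakeLMS:
    # stand-in for a scheduler whose step() accepts neither
    def step(self, noise_pred, t, latents):
        return "lms step, eta not accepted"

def filtered_step(scheduler, noise_pred, t, latents, eta=0.0, generator=None):
    # same inspect.signature check as the commit: build kwargs the
    # scheduler can actually receive, drop the rest silently
    accepted = set(inspect.signature(scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if "eta" in accepted:
        extra_step_kwargs["eta"] = eta
    if "generator" in accepted:
        extra_step_kwargs["generator"] = generator
    return scheduler.step(noise_pred, t, latents, **extra_step_kwargs)

print(filtered_step(FakeDDIM(), None, 0, None, eta=0.3))  # -> ddim step, eta=0.3
print(filtered_step(FakeLMS(), None, 0, None, eta=0.3))   # -> lms step, eta not accepted
```

This is why the pipeline can accept a single `eta` argument while supporting PNDM, LMS, and DDIM schedulers interchangeably.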
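And a minimal end-to-end usage sketch of what the commit enables. It assumes `pipe` is an already-constructed instance of the CLIP-guided pipeline defined in this file, built with a `DDIMScheduler`; `prompt`, `num_inference_steps`, and the `.images` output field follow the usual Stable Diffusion pipeline conventions and are assumptions here, as is `paddle.Generator().manual_seed(...)` (implied by the new `generator: Optional[paddle.Generator]` argument).

```python
import paddle

# Hypothetical usage. The old `seed` argument is gone: randomness now flows
# through a paddle.Generator, which is also handed to the scheduler whenever
# its step() accepts one.
generator = paddle.Generator().manual_seed(42)

image = pipe(
    prompt="a photograph of an astronaut riding a horse",  # assumed parameter name
    num_inference_steps=50,                                # assumed parameter name
    guidance_scale=7.5,
    clip_guidance_scale=100,
    num_cutouts=4,
    use_cutouts=False,
    eta=0.0,              # only consumed by schedulers whose step() accepts it (DDIM)
    generator=generator,  # replaces the removed `seed` argument
).images[0]
```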