@@ -72,13 +72,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
72
72
@register_to_config
73
73
def __init__ (
74
74
self ,
75
- num_train_timesteps = 1000 ,
76
- beta_start = 0.0001 ,
77
- beta_end = 0.02 ,
78
- beta_schedule = "linear" ,
75
+ num_train_timesteps : int = 1000 ,
76
+ beta_start : float = 0.0001 ,
77
+ beta_end : float = 0.02 ,
78
+ beta_schedule : str = "linear" ,
79
79
trained_betas = None ,
80
- tensor_format = "pt" ,
81
- skip_prk_steps = False ,
80
+ tensor_format : str = "pt" ,
81
+ skip_prk_steps : bool = False ,
82
82
):
83
83
if trained_betas is not None :
84
84
self .betas = np .asarray (trained_betas )
@@ -120,7 +120,7 @@ def __init__(
120
120
self .tensor_format = tensor_format
121
121
self .set_format (tensor_format = tensor_format )
122
122
123
- def set_timesteps (self , num_inference_steps , offset = 0 ) :
123
+ def set_timesteps (self , num_inference_steps : int , offset : int = 0 ) -> torch . FloatTensor :
124
124
self .num_inference_steps = num_inference_steps
125
125
self ._timesteps = list (
126
126
range (0 , self .config .num_train_timesteps , self .config .num_train_timesteps // num_inference_steps )
@@ -287,7 +287,13 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
287
287
288
288
return prev_sample
289
289
290
- def add_noise (self , original_samples , noise , timesteps ):
290
+ def add_noise (
291
+ self ,
292
+ original_samples : Union [torch .FloatTensor , np .ndarray ],
293
+ noise : Union [torch .FloatTensor , np .ndarray ],
294
+ timesteps : Union [torch .IntTensor , np .ndarray ],
295
+ ) -> torch .Tensor :
296
+
291
297
sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
292
298
sqrt_alpha_prod = self .match_shape (sqrt_alpha_prod , original_samples )
293
299
sqrt_one_minus_alpha_prod = (1 - self .alphas_cumprod [timesteps ]) ** 0.5
0 commit comments