Skip to content

Commit 50a6088

Browse files
dasparthonatolambert
authored andcommitted
[Type hint] PNDM schedulers (#335)
* [Type hint] PNDM Schedulers * ran make style * updated timesteps type hint * apply suggestions from code review * ran make style * removed unused import
1 parent bb1e5f6 commit 50a6088

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
7272
@register_to_config
7373
def __init__(
7474
self,
75-
num_train_timesteps=1000,
76-
beta_start=0.0001,
77-
beta_end=0.02,
78-
beta_schedule="linear",
75+
num_train_timesteps: int = 1000,
76+
beta_start: float = 0.0001,
77+
beta_end: float = 0.02,
78+
beta_schedule: str = "linear",
7979
trained_betas=None,
80-
tensor_format="pt",
81-
skip_prk_steps=False,
80+
tensor_format: str = "pt",
81+
skip_prk_steps: bool = False,
8282
):
8383
if trained_betas is not None:
8484
self.betas = np.asarray(trained_betas)
@@ -120,7 +120,7 @@ def __init__(
120120
self.tensor_format = tensor_format
121121
self.set_format(tensor_format=tensor_format)
122122

123-
def set_timesteps(self, num_inference_steps, offset=0):
123+
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
124124
self.num_inference_steps = num_inference_steps
125125
self._timesteps = list(
126126
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
@@ -287,7 +287,13 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
287287

288288
return prev_sample
289289

290-
def add_noise(self, original_samples, noise, timesteps):
290+
def add_noise(
291+
self,
292+
original_samples: Union[torch.FloatTensor, np.ndarray],
293+
noise: Union[torch.FloatTensor, np.ndarray],
294+
timesteps: Union[torch.IntTensor, np.ndarray],
295+
) -> torch.Tensor:
296+
291297
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
292298
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
293299
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

0 commit comments

Comments
 (0)