
Commit a0558b1

V Vishnu Anirudh authored
adding more typehints to DDIM scheduler (#456)
* adding more typehints
* resolving mypy issues
* resolving formatting issue
* fixing isort issue

Co-authored-by: V Vishnu Anirudh <git.vva@gmail.com>
Co-authored-by: V Vishnu Anirudh <vvani@kth.se>
1 parent 06924c6 commit a0558b1

File tree

1 file changed: +15 -15 lines changed


src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 15 additions & 15 deletions
@@ -16,7 +16,7 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -25,7 +25,7 @@
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
@@ -43,14 +43,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
-    def alpha_bar(time_step):
+    def calculate_alpha_bar(time_step: float) -> float:
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
 
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    betas: List[float] = []
+    for diffusion_timestep in range(num_diffusion_timesteps):
+        lower_timestep = diffusion_timestep / num_diffusion_timesteps
+        upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
+        betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
     return np.array(betas, dtype=np.float32)
 
 
@@ -96,7 +96,7 @@ def __init__(
         tensor_format: str = "pt",
     ):
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas: np.ndarray = np.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -108,8 +108,8 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas: np.ndarray = 1.0 - self.betas
+        self.alphas_cumprod: np.ndarray = np.cumprod(self.alphas, axis=0)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -118,10 +118,10 @@ def __init__(
         self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
         # setable values
-        self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.num_inference_steps: int = 0
+        self.timesteps: np.ndarray = np.arange(0, num_train_timesteps)[::-1].copy()
 
-        self.tensor_format = tensor_format
+        self.tensor_format: str = tensor_format
         self.set_format(tensor_format=tensor_format)
 
     def _get_variance(self, timestep, prev_timestep):
@@ -134,7 +134,7 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
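For readers who want to sanity-check the annotated helper outside the library, the sketch below simply copies the patched betas_for_alpha_bar into a standalone script and calls it once. The 1000-step argument and the print check are illustrative only and are not part of this commit.

import math
from typing import List

import numpy as np


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
    # Cosine schedule: beta_t = 1 - alpha_bar(t + 1) / alpha_bar(t), clipped at max_beta.
    def calculate_alpha_bar(time_step: float) -> float:
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas: List[float] = []
    for diffusion_timestep in range(num_diffusion_timesteps):
        lower_timestep = diffusion_timestep / num_diffusion_timesteps
        upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
        betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
    return np.array(betas, dtype=np.float32)


if __name__ == "__main__":
    betas = betas_for_alpha_bar(1000)  # illustrative step count, not from the commit
    print(betas.shape, betas.dtype)    # expected: (1000,) float32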
Comments (0)