
Commit b598157: add docstrings

1 parent 309d206

File tree: 7 files changed (+226 / -4 lines changed)

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 26 additions & 1 deletion
@@ -131,6 +131,14 @@ def _get_variance(self, timestep, prev_timestep):
         return variance

     def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -148,7 +156,24 @@ def step(
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
-
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`): TODO
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
         if self.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 26 additions & 1 deletion
@@ -117,6 +117,13 @@ def __init__(
         self.variance_type = variance_type

     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
@@ -166,7 +173,25 @@ def step(
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
-
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            predict_epsilon (`bool`):
+                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+
+        """
         t = timestep

         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
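
The DDPM loop is analogous. The sketch below is again an illustration rather than commit content: it passes a generator because the docstring lists one, and it relies on the default predict_epsilon=True, i.e. the model output is treated as predicted noise.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=1000)

generator = torch.manual_seed(0)
sample = torch.randn(1, 3, 32, 32)
model = lambda x, t: torch.randn_like(x)  # stand-in noise predictor

for t in scheduler.timesteps:
    model_output = model(sample, t)
    # the generator feeds the stochastic part of the DDPM update
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample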

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 40 additions & 0 deletions
@@ -89,6 +89,14 @@ def __init__(
         self.set_format(tensor_format=tensor_format)

     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
         self.schedule = [
@@ -105,6 +113,8 @@ def add_noise_to_input(
         """
         Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
         higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        TODO Args:
         """
         if self.s_min <= sigma <= self.s_max:
             gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
@@ -126,6 +136,21 @@ def step(
         sample_hat: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
+
+        """

         pred_original_sample = sample_hat + sigma_hat * model_output
         derivative = (sample_hat - pred_original_sample) / sigma_hat
@@ -146,7 +171,22 @@ def step_correct(
         derivative: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Correct the predicted sample based on the output model_output of the network. TODO complete description

+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+            derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+        """
         pred_original_sample = sample_prev + sigma_prev * model_output
         derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
         sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
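
The division of labor between add_noise_to_input(), step(), and step_correct() is easiest to see in a loop. This sketch follows the stochastic sampler of Karras et al. (2022); the stand-in denoise callable, the schedule indexing, and the config access are assumptions for illustration, not part of the commit.

import torch
from diffusers import KarrasVeScheduler

scheduler = KarrasVeScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

generator = torch.manual_seed(0)
sample = torch.randn(1, 3, 32, 32) * scheduler.config.sigma_max
denoise = lambda x, sigma: torch.randn_like(x)  # stand-in for the network

for t in scheduler.timesteps:
    sigma = scheduler.schedule[t]
    sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0
    # "churn" step: raise the noise level from sigma to sigma_hat
    sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma, generator=generator)
    output = scheduler.step(denoise(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)
    if sigma_prev != 0:
        # second-order correction using the derivative carried in KarrasVeOutput
        output = scheduler.step_correct(
            denoise(output.prev_sample, sigma_prev), sigma_hat, sigma_prev,
            sample_hat, output.prev_sample, output.derivative,
        )
    sample = output.prev_sample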

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 29 additions & 1 deletion
@@ -80,7 +80,12 @@ def __init__(

     def get_lms_coefficient(self, order, t, current_order):
         """
-        Compute a linear multistep coefficient
+        Compute a linear multistep coefficient.
+
+        Args:
+            order (TODO):
+            t (TODO):
+            current_order (TODO):
         """

         def lms_derivative(tau):
@@ -96,6 +101,13 @@ def lms_derivative(tau):
         return integrated_coeff

     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

@@ -118,6 +130,22 @@ def step(
         order: int = 4,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            order: coefficient for multi-step inference.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
+        """
         sigma = self.sigmas[timestep]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
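
Note that the step() body shown above looks up `sigma = self.sigmas[timestep]`, so in this version the caller passes the loop index into sigmas rather than the raw training timestep. A minimal sketch; the stand-in model and the initial sigma scaling are illustrative assumptions, not commit content.

import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 32, 32) * scheduler.sigmas[0]  # scale initial noise
model = lambda x, t: torch.randn_like(x)  # stand-in noise predictor

for i, t in enumerate(scheduler.timesteps):
    model_output = model(sample, t)
    # order=4 keeps up to four past derivatives for the multistep update
    sample = scheduler.step(model_output, i, sample, order=4).prev_sample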

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 46 additions & 0 deletions
@@ -127,6 +127,14 @@ def __init__(
         self.set_format(tensor_format=tensor_format)

     def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
         self.num_inference_steps = num_inference_steps
         self._timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
@@ -164,7 +172,23 @@ def step(
         sample: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.

+        """
         if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
             return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
         else:
@@ -180,6 +204,17 @@ def step_prk(
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -223,6 +258,17 @@ def step_plms(
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if self.num_inference_steps is None:
             raise ValueError(
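
Because step() dispatches to step_prk() or step_plms() through the internal counter, callers only ever invoke step() itself. A minimal sketch; the stand-in model and shapes are illustrative, not commit content.

import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 32, 32)
model = lambda x, t: torch.randn_like(x)  # stand-in noise predictor

for t in scheduler.timesteps:
    model_output = model(sample, t)
    # the counter decides internally between a Runge-Kutta (PRK) and a
    # linear multistep (PLMS) update; the caller never chooses directly
    sample = scheduler.step(model_output, t, sample).prev_sample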

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 50 additions & 1 deletion
@@ -47,6 +47,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.

+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
     Args:
         snr (`float`):
             coefficient weighting the step from the model_output sample (from the network) to the random noise.
@@ -80,6 +82,15 @@ def __init__(
         self.set_format(tensor_format=tensor_format)

     def set_timesteps(self, num_inference_steps, sampling_eps=None):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
@@ -90,6 +101,20 @@ def set_timesteps(self, num_inference_steps, sampling_eps=None):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
+        """
+        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+        The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sigma_min (`float`, optional):
+                initial noise scale value (overrides value given at Scheduler instantiation).
+            sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
         sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
         sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@@ -141,7 +166,20 @@
         **kwargs,
     ) -> Union[SdeVeOutput, Tuple]:
         """
-        Predict the sample at the previous timestep by reversing the SDE.
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if "seed" in kwargs and kwargs["seed"] is not None:
             self.set_seed(kwargs["seed"])
@@ -187,6 +225,17 @@ def step_correct(
         """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
+
         """
         if "seed" in kwargs and kwargs["seed"] is not None:
             self.set_seed(kwargs["seed"])
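
step_pred() and step_correct() form a predictor-corrector pair: several corrector passes refine the sample at the current noise level, then the predictor moves one step down the schedule. A sketch of that wiring; the stand-in score model, shapes, and config access are illustrative assumptions, not commit content.

import torch
from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=1000)
scheduler.set_sigmas(num_inference_steps=1000)

generator = torch.manual_seed(0)
sample = torch.randn(1, 3, 32, 32) * scheduler.config.sigma_max
model = lambda x, sigma: torch.randn_like(x)  # stand-in score model

for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i] * torch.ones(sample.shape[0])
    # corrector: Langevin-like refinement at the current noise scale
    for _ in range(scheduler.config.correct_steps):
        model_output = model(sample, sigma_t)
        sample = scheduler.step_correct(model_output, sample, generator=generator).prev_sample
    # predictor: reverse-SDE step to the previous timestep
    model_output = model(sample, sigma_t)
    sample = scheduler.step_pred(model_output, t, sample, generator=generator).prev_sample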

src/diffusers/schedulers/scheduling_sde_vp.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,15 @@


 class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+    """
+    The variance preserving stochastic differential equation (SDE) scheduler.
+
+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+    UNDER CONSTRUCTION
+
+    """
+
     @register_to_config
     def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
