Skip to content

Commit 6b09f37

Browse files
authored
[Scheduler design] The pragmatic approach (#719)
* init * improve add_noise * [debug start] run slow test * [debug end] * quick revert * Add docstrings and warnings + API tests * Make the warning less spammy
1 parent 726aba0 commit 6b09f37

13 files changed

+179
-76
lines changed

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __call__(
5757

5858
model = self.unet
5959

60-
sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
60+
sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
6161
sample = sample.to(self.device)
6262

6363
self.scheduler.set_timesteps(num_inference_steps)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,8 @@ def __call__(
281281
# It's more optimized to move all timesteps to correct device beforehand
282282
timesteps_tensor = self.scheduler.timesteps.to(self.device)
283283

284-
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
285-
if isinstance(self.scheduler, LMSDiscreteScheduler):
286-
latents = latents * self.scheduler.sigmas[0]
284+
# scale the initial noise by the standard deviation required by the scheduler
285+
latents = latents * self.scheduler.init_noise_sigma
287286

288287
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
289288
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,10 +296,7 @@ def __call__(
297296
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
298297
# expand the latents if we are doing classifier free guidance
299298
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
300-
if isinstance(self.scheduler, LMSDiscreteScheduler):
301-
sigma = self.scheduler.sigmas[i]
302-
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
303-
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
299+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
304300

305301
# predict the noise residual
306302
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -311,10 +307,7 @@ def __call__(
311307
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
312308

313309
# compute the previous noisy sample x_t -> x_t-1
314-
if isinstance(self.scheduler, LMSDiscreteScheduler):
315-
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
316-
else:
317-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
310+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
318311

319312
# call the callback, if provided
320313
if callback is not None and i % callback_steps == 0:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,9 @@ def __call__(
226226
offset = self.scheduler.config.get("steps_offset", 0)
227227
init_timestep = int(num_inference_steps * strength) + offset
228228
init_timestep = min(init_timestep, num_inference_steps)
229-
if isinstance(self.scheduler, LMSDiscreteScheduler):
230-
timesteps = torch.tensor(
231-
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
232-
)
233-
else:
234-
timesteps = self.scheduler.timesteps[-init_timestep]
235-
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
229+
230+
timesteps = self.scheduler.timesteps[-init_timestep]
231+
timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
236232

237233
# add noise to latents using the timesteps
238234
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -310,16 +306,9 @@ def __call__(
310306
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
311307

312308
for i, t in enumerate(self.progress_bar(timesteps)):
313-
t_index = t_start + i
314-
315309
# expand the latents if we are doing classifier free guidance
316310
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
317-
318-
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
319-
if isinstance(self.scheduler, LMSDiscreteScheduler):
320-
sigma = self.scheduler.sigmas[t_index]
321-
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
322-
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
311+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
323312

324313
# predict the noise residual
325314
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -330,10 +319,7 @@ def __call__(
330319
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
331320

332321
# compute the previous noisy sample x_t -> x_t-1
333-
if isinstance(self.scheduler, LMSDiscreteScheduler):
334-
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
335-
else:
336-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
322+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
337323

338324
# call the callback, if provided
339325
if callback is not None and i % callback_steps == 0:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,9 @@ def __call__(
260260
offset = self.scheduler.config.get("steps_offset", 0)
261261
init_timestep = int(num_inference_steps * strength) + offset
262262
init_timestep = min(init_timestep, num_inference_steps)
263-
if isinstance(self.scheduler, LMSDiscreteScheduler):
264-
timesteps = torch.tensor(
265-
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
266-
)
267-
else:
268-
timesteps = self.scheduler.timesteps[-init_timestep]
269-
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
263+
264+
timesteps = self.scheduler.timesteps[-init_timestep]
265+
timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
270266

271267
# add noise to latents using the timesteps
272268
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -348,13 +344,9 @@ def __call__(
348344
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
349345

350346
for i, t in tqdm(enumerate(timesteps)):
351-
t_index = t_start + i
352347
# expand the latents if we are doing classifier free guidance
353348
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
354-
if isinstance(self.scheduler, LMSDiscreteScheduler):
355-
sigma = self.scheduler.sigmas[t_index]
356-
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
357-
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
349+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
358350

359351
# predict the noise residual
360352
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -365,14 +357,9 @@ def __call__(
365357
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
366358

367359
# compute the previous noisy sample x_t -> x_t-1
368-
if isinstance(self.scheduler, LMSDiscreteScheduler):
369-
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
370-
# masking
371-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
372-
else:
373-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
374-
# masking
375-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
360+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
361+
# masking
362+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
376363

377364
latents = (init_latents_proper * mask) + (latents * (1 - mask))
378365

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ def __call__(
147147
# set timesteps
148148
self.scheduler.set_timesteps(num_inference_steps)
149149

150-
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
151-
if isinstance(self.scheduler, LMSDiscreteScheduler):
152-
latents = latents * self.scheduler.sigmas[0]
150+
latents = latents * self.scheduler.init_noise_sigma
153151

154152
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
155153
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -163,10 +161,7 @@ def __call__(
163161
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
164162
# expand the latents if we are doing classifier free guidance
165163
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
166-
if isinstance(self.scheduler, LMSDiscreteScheduler):
167-
sigma = self.scheduler.sigmas[i]
168-
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
169-
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
164+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
170165

171166
# predict the noise residual
172167
noise_pred = self.unet(
@@ -180,11 +175,7 @@ def __call__(
180175
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
181176

182177
# compute the previous noisy sample x_t -> x_t-1
183-
if isinstance(self.scheduler, LMSDiscreteScheduler):
184-
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
185-
else:
186-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
187-
178+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
188179
latents = np.array(latents)
189180

190181
# call the callback, if provided

src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __call__(
6969
model = self.unet
7070

7171
# sample x_0 ~ N(0, sigma_0^2 * I)
72-
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
72+
sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
7373
sample = sample.to(self.device)
7474

7575
self.scheduler.set_timesteps(num_inference_steps)

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,27 @@ def __init__(
152152
# whether we use the final alpha of the "non-previous" one.
153153
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
154154

155+
# standard deviation of the initial noise distribution
156+
self.init_noise_sigma = 1.0
157+
155158
# setable values
156159
self.num_inference_steps = None
157160
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
158161

162+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
163+
"""
164+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
165+
current timestep.
166+
167+
Args:
168+
sample (`torch.FloatTensor`): input sample
169+
timestep (`int`, optional): current timestep
170+
171+
Returns:
172+
`torch.FloatTensor`: scaled input sample
173+
"""
174+
return sample
175+
159176
def _get_variance(self, timestep, prev_timestep):
160177
alpha_prod_t = self.alphas_cumprod[timestep]
161178
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,29 @@ def __init__(
140140
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
141141
self.one = torch.tensor(1.0)
142142

143+
# standard deviation of the initial noise distribution
144+
self.init_noise_sigma = 1.0
145+
143146
# setable values
144147
self.num_inference_steps = None
145148
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
146149

147150
self.variance_type = variance_type
148151

152+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
153+
"""
154+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
155+
current timestep.
156+
157+
Args:
158+
sample (`torch.FloatTensor`): input sample
159+
timestep (`int`, optional): current timestep
160+
161+
Returns:
162+
`torch.FloatTensor`: scaled input sample
163+
"""
164+
return sample
165+
149166
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
150167
"""
151168
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,28 @@ def __init__(
9595
take_from=kwargs,
9696
)
9797

98+
# standard deviation of the initial noise distribution
99+
self.init_noise_sigma = sigma_max
100+
98101
# setable values
99102
self.num_inference_steps: int = None
100103
self.timesteps: np.IntTensor = None
101104
self.schedule: torch.FloatTensor = None # sigma(t_i)
102105

106+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
107+
"""
108+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
109+
current timestep.
110+
111+
Args:
112+
sample (`torch.FloatTensor`): input sample
113+
timestep (`int`, optional): current timestep
114+
115+
Returns:
116+
`torch.FloatTensor`: scaled input sample
117+
"""
118+
return sample
119+
103120
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
104121
"""
105122
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import warnings
1515
from dataclasses import dataclass
1616
from typing import Optional, Tuple, Union
1717

@@ -102,11 +102,36 @@ def __init__(
102102
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
103103
self.sigmas = torch.from_numpy(sigmas)
104104

105+
# standard deviation of the initial noise distribution
106+
self.init_noise_sigma = self.sigmas.max()
107+
105108
# setable values
106109
self.num_inference_steps = None
107110
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
108111
self.timesteps = torch.from_numpy(timesteps)
109112
self.derivatives = []
113+
self.is_scale_input_called = False
114+
115+
def scale_model_input(
116+
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
117+
) -> torch.FloatTensor:
118+
"""
119+
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
120+
121+
Args:
122+
sample (`torch.FloatTensor`): input sample
123+
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
124+
125+
Returns:
126+
`torch.FloatTensor`: scaled input sample
127+
"""
128+
if isinstance(timestep, torch.Tensor):
129+
timestep = timestep.to(self.timesteps.device)
130+
step_index = (self.timesteps == timestep).nonzero().item()
131+
sigma = self.sigmas[step_index]
132+
sample = sample / ((sigma**2 + 1) ** 0.5)
133+
self.is_scale_input_called = True
134+
return sample
110135

111136
def get_lms_coefficient(self, order, t, current_order):
112137
"""
@@ -154,7 +179,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
154179
def step(
155180
self,
156181
model_output: torch.FloatTensor,
157-
timestep: int,
182+
timestep: Union[float, torch.FloatTensor],
158183
sample: torch.FloatTensor,
159184
order: int = 4,
160185
return_dict: bool = True,
@@ -165,7 +190,7 @@ def step(
165190
166191
Args:
167192
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
168-
timestep (`int`): current discrete timestep in the diffusion chain.
193+
timestep (`float`): current timestep in the diffusion chain.
169194
sample (`torch.FloatTensor`):
170195
current instance of sample being created by diffusion process.
171196
order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ def step(
177202
When returning a tuple, the first element is the sample tensor.
178203
179204
"""
180-
sigma = self.sigmas[timestep]
205+
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
206+
warnings.warn(
207+
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
208+
"Make sure to pass one of the `scheduler.timesteps`"
209+
)
210+
if not self.is_scale_input_called:
211+
warnings.warn(
212+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
213+
"See `StableDiffusionPipeline` for a usage example."
214+
)
215+
216+
if isinstance(timestep, torch.Tensor):
217+
timestep = timestep.to(self.timesteps.device)
218+
step_index = (self.timesteps == timestep).nonzero().item()
219+
sigma = self.sigmas[step_index]
181220

182221
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
183222
pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ def step(
189228
self.derivatives.pop(0)
190229

191230
# 3. Compute linear multistep coefficients
192-
order = min(timestep + 1, order)
193-
lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
231+
order = min(step_index + 1, order)
232+
lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
194233

195234
# 4. Compute previous sample based on the derivatives path
196235
prev_sample = sample + sum(
@@ -206,12 +245,14 @@ def add_noise(
206245
self,
207246
original_samples: torch.FloatTensor,
208247
noise: torch.FloatTensor,
209-
timesteps: torch.IntTensor,
248+
timesteps: torch.FloatTensor,
210249
) -> torch.FloatTensor:
211250
sigmas = self.sigmas.to(original_samples.device)
251+
schedule_timesteps = self.timesteps.to(original_samples.device)
212252
timesteps = timesteps.to(original_samples.device)
253+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
213254

214-
sigma = sigmas[timesteps].flatten()
255+
sigma = sigmas[step_indices].flatten()
215256
while len(sigma.shape) < len(original_samples.shape):
216257
sigma = sigma.unsqueeze(-1)
217258

0 commit comments

Comments
 (0)