Commit b52684c

Add exponential sigmas to other schedulers and update docs (#9518)
1 parent bac8a24 commit b52684c

13 files changed: +345 / -9 lines

docs/source/en/api/schedulers/overview.md

Lines changed: 6 additions & 5 deletions
@@ -46,11 +46,12 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
 | N/A | [`UniPCMultistepScheduler`] | |
 
 ## Noise schedules and schedule types
-| A1111/k-diffusion | 🤗 Diffusers |
-|---------------------|----------------------------------------|
-| Karras | init with `use_karras_sigmas=True` |
-| sgm_uniform | init with `timestep_spacing="trailing"`|
-| simple | init with `timestep_spacing="trailing"`|
+| A1111/k-diffusion | 🤗 Diffusers |
+|--------------------------|----------------------------------------------------------------------------|
+| Karras | init with `use_karras_sigmas=True` |
+| sgm_uniform | init with `timestep_spacing="trailing"` |
+| simple | init with `timestep_spacing="trailing"` |
+| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
 
 All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
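Before the per-file diffs, here is roughly what opting into the new schedule looks like from user code, following the table above (a minimal sketch; the checkpoint ID and pipeline choice are illustrative and not part of this commit):

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Illustrative checkpoint; any pipeline whose scheduler supports the flag works the same way.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Recreate the scheduler from its existing config, switching on the exponential schedule.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="linspace",
    use_exponential_sigmas=True,
)
```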

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 30 additions & 0 deletions
@@ -111,6 +111,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -138,9 +140,12 @@ def __init__(
         solver_type: str = "logrho",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -255,6 +260,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -366,6 +374,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
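For intuition, the `_convert_to_exponential` helper added above is a log-linear ramp from `sigma_max` down to `sigma_min`. A self-contained sketch of the same computation (the standalone function name and the sample sigma values are illustrative):

```python
import math

import torch

def exponential_sigmas(sigma_max: float, sigma_min: float, num_steps: int) -> torch.Tensor:
    # Evenly spaced in log-sigma, then exponentiated back -- hence "exponential":
    # consecutive sigmas differ by a constant ratio rather than a constant difference.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), num_steps).exp()

print(exponential_sigmas(14.6, 0.03, 5))
# approximately tensor([14.6000, 3.1085, 0.6618, 0.1409, 0.0300]) -- a geometric decay
```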

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 32 additions & 0 deletions
@@ -161,6 +161,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -206,6 +208,7 @@ def __init__(
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -214,6 +217,8 @@
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
+        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -330,6 +335,8 @@ def set_timesteps(
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
         if timesteps is not None and self.config.use_lu_lambdas:
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -378,6 +385,9 @@
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
@@ -510,6 +520,28 @@ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch
         lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return lambdas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
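The guard added at the top of `__init__` makes the schedule-type flags mutually exclusive. A sketch of the resulting behavior (assuming otherwise-default constructor arguments):

```python
from diffusers import DPMSolverMultistepScheduler

# One schedule-type flag at a time is fine.
scheduler = DPMSolverMultistepScheduler(use_exponential_sigmas=True)

# Combining flags is rejected at construction time.
try:
    DPMSolverMultistepScheduler(use_karras_sigmas=True, use_exponential_sigmas=True)
except ValueError as err:
    print(err)  # Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.
```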

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 31 additions & 0 deletions
@@ -124,6 +124,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -158,11 +160,14 @@ def __init__(
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -213,6 +218,7 @@ def __init__(
         self._step_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
 
     @property
     def step_index(self):
@@ -267,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -385,6 +394,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,
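One detail worth noting in the `set_timesteps` hunk: the Karras branch rounds sigma-derived timesteps to integers, while the new exponential branch keeps them fractional. A toy illustration of that difference (`sigma_to_t` below is a simplified stand-in for the scheduler's `_sigma_to_t`, and all values are illustrative):

```python
import numpy as np

# Toy stand-in: interpolate a (fractional) timestep index from a table of log-sigmas.
def sigma_to_t(sigma: float, log_sigmas: np.ndarray) -> float:
    t_grid = np.arange(len(log_sigmas), dtype=np.float64)
    # np.interp needs ascending x, so reverse the descending log-sigma table.
    return float(np.interp(np.log(sigma), log_sigmas[::-1], t_grid[::-1]))

log_sigmas = np.log(np.linspace(10.0, 0.1, 1000))  # descending sigmas
print(sigma_to_t(2.5, log_sigmas))         # exponential-style: fractional timestep
print(round(sigma_to_t(2.5, log_sigmas)))  # Karras-style: rounded to an integer
```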

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 30 additions & 0 deletions
@@ -160,6 +160,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
         noise_sampler_seed (`int`, *optional*, defaults to `None`):
             The random seed to use for the noise sampler. If `None`, a random seed is generated.
         timestep_spacing (`str`, defaults to `"linspace"`):
@@ -182,10 +184,13 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
         noise_sampler_seed: Optional[int] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -341,6 +346,9 @@ def set_timesteps(
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
 
@@ -421,6 +429,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
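Putting the pieces together for this scheduler, a minimal sketch of constructing it with the new flag and materializing a schedule (the printed tensors depend on the default beta schedule, and their lengths follow this scheduler's second-order interleaving, so they are longer than `num_inference_steps`):

```python
from diffusers import DPMSolverSDEScheduler

# Build the scheduler with the new flag and materialize a short schedule.
scheduler = DPMSolverSDEScheduler(use_exponential_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)

print(scheduler.sigmas)     # sigmas spaced log-linearly from sigma_max down toward sigma_min
print(scheduler.timesteps)  # matching timesteps recovered via `_sigma_to_t`
```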
