Skip to content

Commit 2d2240b

Browse files
committed
Add DPM-Solver++(2S) and (2M)
1 parent f4e9985 commit 2d2240b

File tree

1 file changed

+58
-0
lines changed

(file-change summary repeated above: 1 file changed, +58 −0 lines)

k_diffusion/sampling.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,61 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
434434
if return_info:
435435
return x, info
436436
return x
437+
438+
439+
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: sigma = exp(-t).
        return t.neg().exp()

    def t_fn(sigma):
        # Log-sigma time variable: t = -log(sigma).
        return sigma.log().neg()

    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[step], sigmas[step + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})
        if sigma_down == 0:
            # Stepping all the way to sigma = 0: the 2S update is undefined there,
            # so fall back to a plain Euler step.
            d = to_d(x, sigmas[step], denoised)
            x = x + d * (sigma_down - sigmas[step])
        else:
            # DPM-Solver++(2S): one midpoint model evaluation at s = t + h/2,
            # then the full second-order update in t-space.
            t_cur, t_next = t_fn(sigmas[step]), t_fn(sigma_down)
            h = t_next - t_cur
            s = t_cur + h / 2
            x_mid = (sigma_fn(s) / sigma_fn(t_cur)) * x - (-h / 2).expm1() * denoised
            denoised_mid = model(x_mid, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * denoised_mid
        # Ancestral noise injection (a no-op in value when sigma_up == 0, but
        # randn_like is still called, matching the original's RNG consumption).
        x = x + torch.randn_like(x) * s_noise * sigma_up
    return x
469+
470+
471+
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: sigma = exp(-t).
        return t.neg().exp()

    def t_fn(sigma):
        # Log-sigma time variable: t = -log(sigma).
        return sigma.log().neg()

    prev_denoised = None
    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})
        t_cur, t_next = t_fn(sigmas[step]), t_fn(sigmas[step + 1])
        h = t_next - t_cur
        if prev_denoised is None or sigmas[step + 1] == 0:
            # First step (no multistep history yet) or final step to sigma = 0:
            # use the first-order DPM-Solver++ update.
            x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep: linearly extrapolate the denoised estimate
            # using the previous step, weighted by the step-size ratio.
            h_prev = t_cur - t_fn(sigmas[step - 1])
            ratio = h_prev / h
            denoised_d = (1 + 1 / (2 * ratio)) * denoised - (1 / (2 * ratio)) * prev_denoised
            x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * denoised_d
        prev_denoised = denoised
    return x

0 commit comments

Comments
 (0)