@@ -434,3 +434,61 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
434
434
if return_info :
435
435
return x , info
436
436
return x
437
+
438
+
439
+ @torch .no_grad ()
440
+ def sample_dpmpp_2s_ancestral (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. ):
441
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
442
+ extra_args = {} if extra_args is None else extra_args
443
+ s_in = x .new_ones ([x .shape [0 ]])
444
+ sigma_fn = lambda t : t .neg ().exp ()
445
+ t_fn = lambda sigma : sigma .log ().neg ()
446
+
447
+ for i in trange (len (sigmas ) - 1 , disable = disable ):
448
+ denoised = model (x , sigmas [i ] * s_in , ** extra_args )
449
+ sigma_down , sigma_up = get_ancestral_step (sigmas [i ], sigmas [i + 1 ], eta = eta )
450
+ if callback is not None :
451
+ callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
452
+ if sigma_down == 0 :
453
+ # Euler method
454
+ d = to_d (x , sigmas [i ], denoised )
455
+ dt = sigma_down - sigmas [i ]
456
+ x = x + d * dt
457
+ else :
458
+ # DPM-Solver-2++(2S)
459
+ t , t_next = t_fn (sigmas [i ]), t_fn (sigma_down )
460
+ r = 1 / 2
461
+ h = t_next - t
462
+ s = t + r * h
463
+ x_2 = (sigma_fn (s ) / sigma_fn (t )) * x - (- h * r ).expm1 () * denoised
464
+ denoised_2 = model (x_2 , sigma_fn (s ) * s_in , ** extra_args )
465
+ x = (sigma_fn (t_next ) / sigma_fn (t )) * x - (- h ).expm1 () * denoised_2
466
+ # Noise addition
467
+ x = x + torch .randn_like (x ) * s_noise * sigma_up
468
+ return x
469
+
470
+
471
+ @torch .no_grad ()
472
+ def sample_dpmpp_2m (model , x , sigmas , extra_args = None , callback = None , disable = None ):
473
+ """DPM-Solver++(2M)."""
474
+ extra_args = {} if extra_args is None else extra_args
475
+ s_in = x .new_ones ([x .shape [0 ]])
476
+ sigma_fn = lambda t : t .neg ().exp ()
477
+ t_fn = lambda sigma : sigma .log ().neg ()
478
+ old_denoised = None
479
+
480
+ for i in trange (len (sigmas ) - 1 , disable = disable ):
481
+ denoised = model (x , sigmas [i ] * s_in , ** extra_args )
482
+ if callback is not None :
483
+ callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
484
+ t , t_next = t_fn (sigmas [i ]), t_fn (sigmas [i + 1 ])
485
+ h = t_next - t
486
+ if old_denoised is None or sigmas [i + 1 ] == 0 :
487
+ x = (sigma_fn (t_next ) / sigma_fn (t )) * x - (- h ).expm1 () * denoised
488
+ else :
489
+ h_last = t - t_fn (sigmas [i - 1 ])
490
+ r = h_last / h
491
+ denoised_d = (1 + 1 / (2 * r )) * denoised - (1 / (2 * r )) * old_denoised
492
+ x = (sigma_fn (t_next ) / sigma_fn (t )) * x - (- h ).expm1 () * denoised_d
493
+ old_denoised = denoised
494
+ return x
0 commit comments