@@ -426,16 +426,18 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
426
426
scheduler = scheduler_class (** scheduler_config )
427
427
scheduler .set_timesteps (num_inference_steps )
428
428
429
- # copy over dummy past residuals
429
+ # copy over dummy past residuals (must be after setting timesteps)
430
430
scheduler .ets = dummy_past_residuals [:]
431
431
432
432
with tempfile .TemporaryDirectory () as tmpdirname :
433
433
scheduler .save_config (tmpdirname )
434
434
new_scheduler = scheduler_class .from_config (tmpdirname )
435
435
# copy over dummy past residuals
436
- new_scheduler .ets = dummy_past_residuals [:]
437
436
new_scheduler .set_timesteps (num_inference_steps )
438
437
438
+ # copy over dummy past residual (must be after setting timesteps)
439
+ new_scheduler .ets = dummy_past_residuals [:]
440
+
439
441
output = scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
440
442
new_output = new_scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
441
443
@@ -461,19 +463,19 @@ def test_pytorch_equal_numpy(self):
461
463
462
464
scheduler_config = self .get_scheduler_config ()
463
465
scheduler = scheduler_class (tensor_format = "np" , ** scheduler_config )
464
- # copy over dummy past residuals
465
- scheduler .ets = dummy_past_residuals [:]
466
466
467
467
scheduler_pt = scheduler_class (tensor_format = "pt" , ** scheduler_config )
468
- # copy over dummy past residuals
469
- scheduler_pt .ets = dummy_past_residuals_pt [:]
470
468
471
469
if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
472
470
scheduler .set_timesteps (num_inference_steps )
473
471
scheduler_pt .set_timesteps (num_inference_steps )
474
472
elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
475
473
kwargs ["num_inference_steps" ] = num_inference_steps
476
474
475
+ # copy over dummy past residuals (must be done after set_timesteps)
476
+ scheduler .ets = dummy_past_residuals [:]
477
+ scheduler_pt .ets = dummy_past_residuals_pt [:]
478
+
477
479
output = scheduler .step_prk (residual , 1 , sample , ** kwargs )["prev_sample" ]
478
480
output_pt = scheduler_pt .step_prk (residual_pt , 1 , sample_pt , ** kwargs )["prev_sample" ]
479
481
assert np .sum (np .abs (output - output_pt .numpy ())) < 1e-4 , "Scheduler outputs are not identical"
@@ -494,15 +496,16 @@ def test_step_shape(self):
494
496
495
497
sample = self .dummy_sample
496
498
residual = 0.1 * sample
497
- # copy over dummy past residuals
498
- dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
499
- scheduler .ets = dummy_past_residuals [:]
500
499
501
500
if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
502
501
scheduler .set_timesteps (num_inference_steps )
503
502
elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
504
503
kwargs ["num_inference_steps" ] = num_inference_steps
505
504
505
+ # copy over dummy past residuals (must be done after set_timesteps)
506
+ dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
507
+ scheduler .ets = dummy_past_residuals [:]
508
+
506
509
output_0 = scheduler .step_prk (residual , 0 , sample , ** kwargs )["prev_sample" ]
507
510
output_1 = scheduler .step_prk (residual , 1 , sample , ** kwargs )["prev_sample" ]
508
511
0 commit comments