Skip to content

Commit 3f1861e

Browse files
author
Nathan Lambert
authored
hotfix for pdnm test (#220)
1 parent 6a03060 commit 3f1861e

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tests/test_scheduler.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -426,16 +426,18 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
426426
scheduler = scheduler_class(**scheduler_config)
427427
scheduler.set_timesteps(num_inference_steps)
428428

429-
# copy over dummy past residuals
429+
# copy over dummy past residuals (must be after setting timesteps)
430430
scheduler.ets = dummy_past_residuals[:]
431431

432432
with tempfile.TemporaryDirectory() as tmpdirname:
433433
scheduler.save_config(tmpdirname)
434434
new_scheduler = scheduler_class.from_config(tmpdirname)
435435
# copy over dummy past residuals
436-
new_scheduler.ets = dummy_past_residuals[:]
437436
new_scheduler.set_timesteps(num_inference_steps)
438437

438+
# copy over dummy past residual (must be after setting timesteps)
439+
new_scheduler.ets = dummy_past_residuals[:]
440+
439441
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
440442
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
441443

@@ -461,19 +463,19 @@ def test_pytorch_equal_numpy(self):
461463

462464
scheduler_config = self.get_scheduler_config()
463465
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
464-
# copy over dummy past residuals
465-
scheduler.ets = dummy_past_residuals[:]
466466

467467
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
468-
# copy over dummy past residuals
469-
scheduler_pt.ets = dummy_past_residuals_pt[:]
470468

471469
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
472470
scheduler.set_timesteps(num_inference_steps)
473471
scheduler_pt.set_timesteps(num_inference_steps)
474472
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
475473
kwargs["num_inference_steps"] = num_inference_steps
476474

475+
# copy over dummy past residuals (must be done after set_timesteps)
476+
scheduler.ets = dummy_past_residuals[:]
477+
scheduler_pt.ets = dummy_past_residuals_pt[:]
478+
477479
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
478480
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
479481
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
@@ -494,15 +496,16 @@ def test_step_shape(self):
494496

495497
sample = self.dummy_sample
496498
residual = 0.1 * sample
497-
# copy over dummy past residuals
498-
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
499-
scheduler.ets = dummy_past_residuals[:]
500499

501500
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
502501
scheduler.set_timesteps(num_inference_steps)
503502
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
504503
kwargs["num_inference_steps"] = num_inference_steps
505504

505+
# copy over dummy past residuals (must be done after set_timesteps)
506+
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
507+
scheduler.ets = dummy_past_residuals[:]
508+
506509
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
507510
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
508511

0 commit comments

Comments
 (0)