
[Scheduler design] The pragmatic approach #719


Merged: 8 commits, Oct 5, 2022

Conversation

@anton-l (Member) commented Oct 4, 2022

This scheduler API redesign addresses the concerns raised in #336 and starts to make schedulers interchangeable, without requiring scheduler-class-dependent customizations in pipelines.

  1. Every scheduler now contains an init_noise_sigma attribute to scale the normal distribution of the initial noise. It is just 1.0 for DDPM, DDIM, and PNDM, but customized for the VE and K-LMS schedulers.
     Example usage:

```python
sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
```

  2. Every scheduler needs to implement scale_model_input(sample, timestep) (even if it just returns the sample unchanged), which scales the denoising model's input based on the current timestep. The method should be called before every model() call.
     Example usage:

```python
sample = self.scheduler.scale_model_input(sample, t)
output = model(sample, t)
```

Note: the decision not to make scale_model_input a base-class method is intentional, as suggested by @patrickvonplaten

Closes #336
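The contract described above can be sketched with a hypothetical stand-in scheduler. The class name, the toy step rule, and the numbers below are purely illustrative, not diffusers code; only the three public members mirror the PR:

```python
# Hypothetical stub illustrating the API contract this PR introduces:
# every scheduler exposes init_noise_sigma, scale_model_input, and step.
class StubScheduler:
    init_noise_sigma = 1.0  # 1.0 for DDPM/DDIM/PNDM; larger for VE and K-LMS

    def scale_model_input(self, sample, timestep):
        # DDPM/DDIM/PNDM simply return the sample; K-LMS rescales it.
        return sample

    def step(self, model_output, timestep, sample):
        # Toy update rule, purely for illustration.
        return sample - 0.1 * model_output

scheduler = StubScheduler()
sample = 0.5 * scheduler.init_noise_sigma  # scale the initial noise
for t in [2, 1, 0]:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = model_input  # stand-in for model(model_input, t)
    sample = scheduler.step(model_output, t, sample)
```

Because pipelines only touch these three members, swapping one scheduler for another no longer requires class-specific branches in the denoising loop.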

@HuggingFaceDocBuilderDev commented Oct 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a comment

Awesome, love this API design - think it's slim and should solve 99% of our problems! Actually, maybe no need after all to have a SchedulerType.CONTINUOUS in a first step 😍

@anton-l (Member, Author) commented Oct 4, 2022

The slow tests pass ✔️

```python
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
```
@anton-l (Member, Author) commented Oct 4, 2022

Really dislike what we have to do here, but unfortunately there's no good vectorized alternative that searches for multiple indices while keeping the order the same.
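A plain-Python sketch of the lookup being discussed (toy lists standing in for the actual tensors) shows why the per-element search is needed: the result indices must follow the order of `timesteps`, not of `schedule_timesteps`:

```python
# Toy stand-ins for the tensors in the snippet under review.
schedule_timesteps = [999, 800, 600, 400, 200, 0]
timesteps = [600, 0, 800]  # arbitrary order, e.g. as passed to add_noise

# Equivalent of [(schedule_timesteps == t).nonzero().item() for t in timesteps]:
# for each requested timestep, find its position in the schedule.
step_indices = [schedule_timesteps.index(t) for t in timesteps]
```

A single vectorized comparison would return matches in schedule order, so the per-`t` loop is what preserves the caller's ordering.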

Contributor commented:

don't think it's that bad honestly

@patrickvonplaten (Contributor) commented:

Really cool - like this new design a lot!

Think we can merge this tomorrow 😍

Some final TODOs / suggestions:

  • add the design now to all other schedulers and their scripts (or leave clear TODOs)
  • think we could add a test that forces every new scheduler to have:
    • an init_noise_sigma variable
    • a scale_model_input function
    • a step function
    • => this would ensure that we follow this design in the future. If some schedulers don't follow this design yet, let's maybe exempt them from this test for now, with a big TODO to fix it
  • give scale_model_input nice docstrings and make sure it's displayed well in the docs. Also, let's maybe add a short comment to DDIM, PNDM, and DDPM stating that they don't need model scaling
  • add a big ⚠️ ⚠️ to this PR that it's backwards breaking, and let's maybe try to make it easy for users to fix their code by:
    • if the first timestep passed to LMS is an int and is 0, it's very likely wrong (let's throw a warning here)
    • if the LMS step function is used before scale_model_input has been called, it's most likely wrong (let's maybe even throw an error here)
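The API-conformance test suggested above could look roughly like this. This is a hypothetical sketch, not the test added in this PR; `DummyScheduler` is a stand-in conforming class:

```python
def check_scheduler_public_api(scheduler_cls):
    # Sketch of the suggested test: every scheduler must expose
    # init_noise_sigma, scale_model_input, and step.
    scheduler = scheduler_cls()
    assert hasattr(scheduler, "init_noise_sigma")
    assert callable(getattr(scheduler, "scale_model_input", None))
    assert callable(getattr(scheduler, "step", None))

class DummyScheduler:
    init_noise_sigma = 1.0
    def scale_model_input(self, sample, timestep):
        return sample
    def step(self, model_output, timestep, sample):
        return sample

check_scheduler_public_api(DummyScheduler)  # conforming class passes
```

Running such a check over every registered scheduler class would lock in the design for future additions, with an explicit exemption list for stragglers.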

Comment on lines +210 to +214

```python
if not self.is_scale_input_called:
    warnings.warn(
        "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
        "See `StableDiffusionPipeline` for a usage example."
    )
```
@anton-l (Member, Author) commented:

This will pop up in existing community pipelines, but won't break them the way an exception would. The legacy pipelines can continue using the manual scaling code 👍
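The guard above relies on a flag that scale_model_input flips; a minimal sketch of that interplay (stub class with the attribute name from the snippet; the rest is illustrative):

```python
import warnings

class GuardedScheduler:
    def __init__(self):
        self.is_scale_input_called = False

    def scale_model_input(self, sample, timestep):
        # Flipping the flag records that the caller scaled the input.
        self.is_scale_input_called = True
        return sample

    def step(self, model_output, timestep, sample):
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before "
                "`step` to ensure correct denoising."
            )
        return sample
```

A legacy loop that never calls scale_model_input gets a one-time nudge per scheduler instance, while updated loops run warning-free.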

```diff
@@ -226,6 +226,27 @@ def recursive_check(tuple_object, dict_object):

         recursive_check(outputs_tuple, outputs_dict)

+    def test_scheduler_public_api(self):
```
Contributor commented:

Very nice!

@patrickvonplaten (Contributor) left a comment

Very nice! Happy to merge and help you update existing notebooks / docs / blog post now:

  • Update the PR description so it states exactly what people have to change if they have written their own custom loop. E.g.: if you have been using the K-LMS scheduler, please make sure to do the following:
  • If you have been using other schedulers, no need to change anything, but we recommend for generality to always make use of init_noise_sigma and scale_model_input
  • All blog posts
  • All notebooks
  • All training examples
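For K-LMS users, the migration amounts to replacing manual sigma arithmetic with the two new calls. A minimal numeric sketch of the input scaling (the formula `sample / sqrt(sigma**2 + 1)` matches K-LMS input scaling; the helper and the numbers are illustrative):

```python
# Manual scaling that custom K-LMS loops used to do by hand:
sigma = 3.0
sample = 2.0
manual = sample / ((sigma**2 + 1) ** 0.5)

# After this PR, scheduler.scale_model_input performs the same rescaling
# internally; this hypothetical helper mirrors that computation.
def scale_model_input(sample, sigma):
    return sample / ((sigma**2 + 1) ** 0.5)

assert scale_model_input(sample, sigma) == manual
```

Likewise, multiplying the initial latents by `scheduler.init_noise_sigma` replaces the old hand-written multiplication by the first sigma.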

@anton-l anton-l merged commit 6b09f37 into main Oct 5, 2022
@patrickvonplaten patrickvonplaten deleted the scheduler-refactor-pragmatic branch October 5, 2022 13:13
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* init

* improve add_noise

* [debug start] run slow test

* [debug end]

* quick revert

* Add docstrings and warnings + API tests

* Make the warning less spammy
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
Successfully merging this pull request may close these issues.

scheduler leaky abstractions in pipelines
3 participants