Skip to content

Adding pred_original_sample to SchedulerOutput for some samplers #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 22, 2022
Merged

Adding pred_original_sample to SchedulerOutput for some samplers #614

merged 8 commits into from
Sep 22, 2022

Conversation

johnowhitaker
Copy link
Contributor

Modified DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs. SchedulerOutput (or KarrasVeOutput in the case of KarrasVeScheduler) now has pred_original_sample in addition to prev_sample.

I've checked that all these work (fortunately the changes aren't doing anything too drastic since all these already calculate pred_original_sample internal to their step functions). But I haven't checked to make sure there aren't others using SchedulerOutput - just in case I made pred_original_sample Optional and set default to None.

I haven't touched the flax versions since I'd need to brush up on my jax

Quickest way to test (example DDIM) is as follows:

from diffusers import UNet2DModel
from diffusers import DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler
import torch
repo_id = "google/ddpm-celebahq-256"
model = UNet2DModel.from_pretrained(repo_id)
scheduler = DDIMScheduler.from_config(repo_id)
scheduler.set_timesteps(40)

# Random noise as our starting point
noisy_sample = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
)
print('Noisy input shape:', noisy_sample.shape)

# Getting the model prediction for a given timestep (it predicts the noise residual)
with torch.no_grad():
    noisy_residual = model(sample=noisy_sample, timestep=2).sample
print('Model output shape:', noisy_residual.shape)

# Use this prediction to 'step' with the scheduler
less_noisy_sample = scheduler.step(
    model_output=noisy_residual, timestep=2, sample=noisy_sample
).prev_sample
print('Next step:', less_noisy_sample.shape)

# And check it has pred_original_sample
pred_t0 = scheduler.step(
    model_output=noisy_residual, timestep=2, sample=noisy_sample
).pred_original_sample
print('Pred og sample:', pred_t0.shape)

pndm, sde_ve and sde_vp are a little different, so left them alone for now.

…Scheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs
@johnowhitaker
Copy link
Contributor Author

Just checked that e.g. PNDMScheduler (which also uses SchedulerOutput) still works as expected, where scheculer_output.prev_sample is the same as before, and scheculer_output.pred_original_sample = None. So the addition of pred_original_sample to the SchedulerOutput class won't break things that don't compute it.
Between this and diffusers/tests/test_pipelines.py passing still, I'm fairly confident I haven't introduced any breaking changes!

@johnowhitaker
Copy link
Contributor Author

Not sure why the code check on src/diffusers/schedulers/scheduling_karras_ve.py is failing?

@anton-l
Copy link
Member

anton-l commented Sep 22, 2022

Not sure why the code check on src/diffusers/schedulers/scheduling_karras_ve.py is failing?

Running make style should fix that :)

…output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@johnowhitaker
Copy link
Contributor Author

Great, have undone the change to the default SchedulerOutput and now DDPMScheduler, DDIMScheduler and LMSDiscreteScheduler all have their own output classes defined in their scheduler code files.

@johnowhitaker
Copy link
Contributor Author

Strange the prev one passed everything but after that change its back to failing the code check, but if I run black --check --preview examples tests src utils scripts the output is
All done! ✨ 🍰 ✨
119 files would be left unchanged.
And make style makes no changes.

@johnowhitaker
Copy link
Contributor Author

Think I figured it out, needed to pip install isort, black and hf-doc-builder for make style to be able to fully work its magic.

Copy link
Member

@anton-l anton-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thank you @johnowhitaker!

@anton-l anton-l merged commit 91db818 into huggingface:main Sep 22, 2022
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…gingface#614)

* Adding pred_original_sample to SchedulerOutput of DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs

* Gave DDPMScheduler, DDIMScheduler and LMSDiscreteScheduler their own output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra

* Reordered library imports to follow standard

* didnt get import order quite right apparently

* Forgot to change name of LMSDiscreteSchedulerOutput

* Aha, needed some extra libs for make style to fully work
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants