Skip to content

Add back-compatibility to LMS timesteps #750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,6 @@ def step(
When returning a tuple, the first element is the sample tensor.

"""
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
warnings.warn(
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
"Make sure to pass one of the `scheduler.timesteps`"
)
if not self.is_scale_input_called:
warnings.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
Expand All @@ -215,7 +210,18 @@ def step(

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could lead to some nasty silent bugs though if the "new" passed timesteps are ints instead of float no? Not sure whether this is a good idea tbh

isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
warnings.warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the deprecate functionality here: #659

"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
" 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
)
step_index = timestep
else:
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
Expand Down Expand Up @@ -250,7 +256,14 @@ def add_noise(
sigmas = self.sigmas.to(original_samples.device)
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
warnings.warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the deprecate functionality here

"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
" version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
)
step_indices = timesteps
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down