Skip to content

Commit e910006

Browse files
anton-lPrathik Rao
authored andcommitted
Add back-compatibility to LMS timesteps (huggingface#750)
* Add back-compatibility to LMS timesteps * style
1 parent 72a52f3 commit e910006

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,6 @@ def step(
202202
When returning a tuple, the first element is the sample tensor.
203203
204204
"""
205-
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
206-
warnings.warn(
207-
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
208-
"Make sure to pass one of the `scheduler.timesteps`"
209-
)
210205
if not self.is_scale_input_called:
211206
warnings.warn(
212207
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
@@ -215,7 +210,18 @@ def step(
215210

216211
if isinstance(timestep, torch.Tensor):
217212
timestep = timestep.to(self.timesteps.device)
218-
step_index = (self.timesteps == timestep).nonzero().item()
213+
if (
214+
isinstance(timestep, int)
215+
or isinstance(timestep, torch.IntTensor)
216+
or isinstance(timestep, torch.LongTensor)
217+
):
218+
warnings.warn(
219+
"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
220+
" 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
221+
)
222+
step_index = timestep
223+
else:
224+
step_index = (self.timesteps == timestep).nonzero().item()
219225
sigma = self.sigmas[step_index]
220226

221227
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -250,7 +256,14 @@ def add_noise(
250256
sigmas = self.sigmas.to(original_samples.device)
251257
schedule_timesteps = self.timesteps.to(original_samples.device)
252258
timesteps = timesteps.to(original_samples.device)
253-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
259+
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
260+
warnings.warn(
261+
"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
262+
" version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
263+
)
264+
step_indices = timesteps
265+
else:
266+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
254267

255268
sigma = sigmas[step_indices].flatten()
256269
while len(sigma.shape) < len(original_samples.shape):

0 commit comments

Comments
 (0)