@@ -202,11 +202,6 @@ def step(
202
202
When returning a tuple, the first element is the sample tensor.
203
203
204
204
"""
205
- if not isinstance (timestep , float ) and not isinstance (timestep , torch .FloatTensor ):
206
- warnings .warn (
207
- f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not { type (timestep )} . "
208
- "Make sure to pass one of the `scheduler.timesteps`"
209
- )
210
205
if not self .is_scale_input_called :
211
206
warnings .warn (
212
207
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
@@ -215,7 +210,18 @@ def step(
215
210
216
211
if isinstance (timestep , torch .Tensor ):
217
212
timestep = timestep .to (self .timesteps .device )
218
- step_index = (self .timesteps == timestep ).nonzero ().item ()
213
+ if (
214
+ isinstance (timestep , int )
215
+ or isinstance (timestep , torch .IntTensor )
216
+ or isinstance (timestep , torch .LongTensor )
217
+ ):
218
+ warnings .warn (
219
+ "Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
220
+ " 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
221
+ )
222
+ step_index = timestep
223
+ else :
224
+ step_index = (self .timesteps == timestep ).nonzero ().item ()
219
225
sigma = self .sigmas [step_index ]
220
226
221
227
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -250,7 +256,14 @@ def add_noise(
250
256
sigmas = self .sigmas .to (original_samples .device )
251
257
schedule_timesteps = self .timesteps .to (original_samples .device )
252
258
timesteps = timesteps .to (original_samples .device )
253
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
259
+ if isinstance (timesteps , torch .IntTensor ) or isinstance (timesteps , torch .LongTensor ):
260
+ warnings .warn (
261
+ "Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
262
+ " version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
263
+ )
264
+ step_indices = timesteps
265
+ else :
266
+ step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
254
267
255
268
sigma = sigmas [step_indices ].flatten ()
256
269
while len (sigma .shape ) < len (original_samples .shape ):
0 commit comments