From 189f69420c752f5db630417f444344b68dec0e0b Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 6 Oct 2022 12:55:34 +0200 Subject: [PATCH 1/2] Add back-compatibility to LMS timesteps --- .../schedulers/scheduling_lms_discrete.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index a55811a0629f..13345b64efb8 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -202,11 +202,6 @@ def step( When returning a tuple, the first element is the sample tensor. """ - if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor): - warnings.warn( - f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. " - "Make sure to pass one of the `scheduler.timesteps`" - ) if not self.is_scale_input_called: warnings.warn( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " @@ -215,7 +210,18 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + warnings.warn( + f"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version" + f" 0.5.0. Make sure to pass one of the `scheduler.timesteps`." + ) + step_index = timestep + else: + step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise @@ -250,7 +256,14 @@ def add_noise( sigmas = self.sigmas.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): + warnings.warn( + f"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in" + f" version 0.5.0. Make sure to pass values from `scheduler.timesteps`." + ) + step_indices = timesteps + else: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): From c437eed103166415872fbc63dd69d623f5eb828b Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 6 Oct 2022 12:57:55 +0200 Subject: [PATCH 2/2] style --- src/diffusers/schedulers/scheduling_lms_discrete.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 13345b64efb8..a4c1d74bfe50 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -216,8 +216,8 @@ def step( or isinstance(timestep, torch.LongTensor) ): warnings.warn( - f"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version" - f" 0.5.0. Make sure to pass one of the `scheduler.timesteps`." + "Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version" + " 0.5.0. Make sure to pass one of the `scheduler.timesteps`." ) step_index = timestep else: @@ -258,8 +258,8 @@ def add_noise( timesteps = timesteps.to(original_samples.device) if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): warnings.warn( - f"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in" - f" version 0.5.0. Make sure to pass values from `scheduler.timesteps`." + "Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in" + " version 0.5.0. Make sure to pass values from `scheduler.timesteps`." ) step_indices = timesteps else: