diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a5369b1603c6..0613ffd41d0e 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,13 +17,33 @@ import math import warnings +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -179,7 +199,7 @@ def step( use_clipped_model_output: bool = False, generator=None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -192,11 +212,11 @@ def step( eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -261,7 +281,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d008b84da6e7..440b880385d4 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,13 +15,33 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -177,7 +197,7 @@ def step( predict_epsilon=True, generator=None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -190,11 +210,11 @@ def step( predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -242,7 +262,7 @@ def step( if not return_dict: return (pred_prev_sample,) - return SchedulerOutput(prev_sample=pred_prev_sample) + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index caf7625fb683..6e66bed400f4 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput): denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor derivative: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class KarrasVeScheduler(SchedulerMixin, ConfigMixin): @@ -153,7 +157,7 @@ def step( sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). Returns: @@ -170,7 +174,9 @@ def step( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def step_correct( self, @@ -192,7 +198,7 @@ def step_correct( sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class Returns: prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO @@ -205,7 +211,9 @@ def step_correct( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def add_noise(self, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5857ae70a856..1dd6dbda1e19 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np @@ -20,7 +21,26 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class LMSDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -133,7 +153,7 @@ def step( sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -144,12 +164,12 @@ def step( sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ sigma = self.sigmas[timestep] @@ -175,7 +195,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self,