
Adding pred_original_sample to SchedulerOutput for some samplers #614

Merged
8 commits · Sep 22, 2022
32 changes: 26 additions & 6 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -17,13 +17,33 @@

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input in
the denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -179,7 +199,7 @@ def step(
use_clipped_model_output: bool = False,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -192,11 +212,11 @@ def step(
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.

"""
@@ -261,7 +281,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
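With this change, DDIM's `step` surfaces the scheduler's internal estimate of the denoised sample alongside `prev_sample`. For DDIM that is the standard estimate `(sample - sqrt(1 - alpha_prod_t) * model_output) / sqrt(alpha_prod_t)`. A minimal usage sketch (not part of this diff) — `model` below is a hypothetical stand-in for a trained epsilon-predicting UNet:

```python
import torch

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 64, 64)  # start the reverse process from pure noise

def model(x, t):
    # hypothetical stand-in for a trained epsilon-predicting UNet
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    out = scheduler.step(model(sample, t), t, sample)
    sample = out.prev_sample           # x_{t-1}: fed back in as the next model input
    x0_est = out.pred_original_sample  # new field: the current x_0 estimate, e.g. for previews
```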
32 changes: 26 additions & 6 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input in
the denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -177,7 +197,7 @@ def step(
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -190,11 +210,11 @@ def step(
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.

"""
@@ -242,7 +262,7 @@ def step(
if not return_dict:
return (pred_prev_sample,)

return SchedulerOutput(prev_sample=pred_prev_sample)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
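Because the new output classes subclass `BaseOutput`, the change stays backward compatible: the result still indexes like a tuple, and `return_dict=False` keeps returning a plain tuple whose first element is the sample. A small compatibility sketch (not part of this diff), with a zero tensor standing in for a real epsilon prediction:

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(1000)
t = scheduler.timesteps[0]

sample = torch.randn(1, 3, 32, 32)
noise_pred = torch.zeros_like(sample)  # stand-in for a model's epsilon prediction

out = scheduler.step(noise_pred, t, sample)  # DDPMSchedulerOutput
assert torch.equal(out.prev_sample, out[0])  # attribute and tuple-style access agree

(prev_sample,) = scheduler.step(noise_pred, t, sample, return_dict=False)  # legacy tuple path
```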
16 changes: 12 additions & 4 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@@ -153,7 +157,7 @@ def step(
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

Returns:
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
@@ -170,7 +174,9 @@
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)

def step_correct(
self,
@@ -192,7 +198,7 @@ def step_correct(
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
@@ -205,7 +211,9 @@
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)

def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
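`KarrasVeOutput` already existed with two required fields; `pred_original_sample` is added as an `Optional` field defaulting to `None`, so code that builds or unpacks the old two-field output is unaffected. A minimal sketch (not part of this diff), relying on `BaseOutput` omitting unset fields from `to_tuple()`:

```python
import torch

from diffusers.schedulers.scheduling_karras_ve import KarrasVeOutput

prev = torch.zeros(1, 3, 8, 8)
deriv = torch.zeros_like(prev)

old_style = KarrasVeOutput(prev_sample=prev, derivative=deriv)
print(old_style.pred_original_sample)  # None: the new field defaults when omitted
print(len(old_style.to_tuple()))       # 2: None-valued fields are excluded from the tuple
```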
34 changes: 27 additions & 7 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
@@ -20,7 +21,26 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input in
the denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -133,7 +153,7 @@ def step(
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -144,12 +164,12 @@ def step(
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
[`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.

"""
sigma = self.sigmas[timestep]
Expand All @@ -175,7 +195,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
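The docstrings note that `pred_original_sample` "can be used to preview progress or for guidance". A guidance sketch under stated assumptions (illustrative only, not an API from this PR): `target` and `scale` are hypothetical, `model` stands in for a trained epsilon predictor, and the loop index is passed as `timestep` because this scheduler looks up `self.sigmas[timestep]` directly, as visible above:

```python
import torch

from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32) * scheduler.sigmas[0]
target = torch.zeros_like(sample)  # hypothetical guidance target
scale = 0.1                        # hypothetical guidance step size

def model(x, i):
    return torch.zeros_like(x)     # stand-in for a trained epsilon model

for i in range(len(scheduler.timesteps)):
    x = sample.detach().requires_grad_(True)
    out = scheduler.step(model(x, i), i, x)
    # steer the trajectory by differentiating a loss on the x_0 estimate
    loss = torch.nn.functional.mse_loss(out.pred_original_sample, target)
    grad = torch.autograd.grad(loss, x)[0]
    sample = out.prev_sample.detach() - scale * grad
```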