Skip to content

[Type hint] scheduling ddim #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# and https://github.com/hojonathanho/diffusion

import math
from typing import Union
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -52,15 +52,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_sample=True,
set_alpha_to_one=True,
tensor_format="pt",
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
tensor_format: str = "pt",
):

if beta_schedule == "linear":
Expand Down Expand Up @@ -100,7 +100,7 @@ def _get_variance(self, timestep, prev_timestep):

return variance

def set_timesteps(self, num_inference_steps, offset=0):
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
Expand Down Expand Up @@ -176,7 +176,12 @@ def step(

return {"prev_sample": prev_sample}

def add_noise(self, original_samples, noise, timesteps):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
Expand Down