Skip to content

[Type hint] scheduling lms discrete #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from dataclasses import dataclass
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sigma_min=0.02,
sigma_max=100,
s_noise=1.007,
s_churn=80,
s_min=0.05,
s_max=50,
tensor_format="pt",
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
):
"""
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Expand All @@ -87,7 +87,7 @@ def __init__(
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
Expand All @@ -98,7 +98,9 @@ def set_timesteps(self, num_inference_steps):

self.set_format(tensor_format=self.tensor_format)

def add_noise_to_input(self, sample, sigma, generator=None):
def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
Expand Down
25 changes: 15 additions & 10 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
tensor_format="pt",
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Expand Down Expand Up @@ -79,7 +79,7 @@ def lms_derivative(tau):

return integrated_coeff

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

Expand Down Expand Up @@ -127,7 +127,12 @@ def step(

return SchedulerOutput(prev_sample=prev_sample)

def add_noise(self, original_samples, noise, timesteps):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas

Expand Down