Skip to content

Commit 12f0480

Browse files
santiviqueznatolambert
authored andcommitted
[Type hint] scheduling karras ve (#359)
1 parent 6e539e8 commit 12f0480

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
from dataclasses import dataclass
17-
from typing import Tuple, Union
17+
from typing import Optional, Tuple, Union
1818

1919
import numpy as np
2020
import torch
@@ -72,13 +72,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
7272
@register_to_config
7373
def __init__(
7474
self,
75-
sigma_min=0.02,
76-
sigma_max=100,
77-
s_noise=1.007,
78-
s_churn=80,
79-
s_min=0.05,
80-
s_max=50,
81-
tensor_format="pt",
75+
sigma_min: float = 0.02,
76+
sigma_max: float = 100,
77+
s_noise: float = 1.007,
78+
s_churn: float = 80,
79+
s_min: float = 0.05,
80+
s_max: float = 50,
81+
tensor_format: str = "pt",
8282
):
8383
# setable values
8484
self.num_inference_steps = None
@@ -88,7 +88,7 @@ def __init__(
8888
self.tensor_format = tensor_format
8989
self.set_format(tensor_format=tensor_format)
9090

91-
def set_timesteps(self, num_inference_steps):
91+
def set_timesteps(self, num_inference_steps: int):
9292
self.num_inference_steps = num_inference_steps
9393
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
9494
self.schedule = [
@@ -99,7 +99,9 @@ def set_timesteps(self, num_inference_steps):
9999

100100
self.set_format(tensor_format=self.tensor_format)
101101

102-
def add_noise_to_input(self, sample, sigma, generator=None):
102+
def add_noise_to_input(
103+
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
104+
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
103105
"""
104106
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
105107
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.

0 commit comments

Comments
 (0)