Skip to content

Commit 5eff912

Browse files
santiviqueznatolambert
authored andcommitted
[Type hint] scheduling lms discrete (#360)
* [Type hint] scheduling karras ve * [Type hint] scheduling lms discrete
1 parent 12f0480 commit 5eff912

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Tuple, Union
15+
from typing import Optional, Tuple, Union
1616

1717
import numpy as np
1818
import torch
@@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
2727
@register_to_config
2828
def __init__(
2929
self,
30-
num_train_timesteps=1000,
31-
beta_start=0.0001,
32-
beta_end=0.02,
33-
beta_schedule="linear",
34-
trained_betas=None,
35-
timestep_values=None,
36-
tensor_format="pt",
30+
num_train_timesteps: int = 1000,
31+
beta_start: float = 0.0001,
32+
beta_end: float = 0.02,
33+
beta_schedule: str = "linear",
34+
trained_betas: Optional[np.ndarray] = None,
35+
timestep_values: Optional[np.ndarray] = None,
36+
tensor_format: str = "pt",
3737
):
3838
"""
3939
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
@@ -79,7 +79,7 @@ def lms_derivative(tau):
7979

8080
return integrated_coeff
8181

82-
def set_timesteps(self, num_inference_steps):
82+
def set_timesteps(self, num_inference_steps: int):
8383
self.num_inference_steps = num_inference_steps
8484
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
8585

@@ -127,7 +127,12 @@ def step(
127127

128128
return SchedulerOutput(prev_sample=prev_sample)
129129

130-
def add_noise(self, original_samples, noise, timesteps):
130+
def add_noise(
131+
self,
132+
original_samples: Union[torch.FloatTensor, np.ndarray],
133+
noise: Union[torch.FloatTensor, np.ndarray],
134+
timesteps: Union[torch.IntTensor, np.ndarray],
135+
) -> Union[torch.FloatTensor, np.ndarray]:
131136
sigmas = self.match_shape(self.sigmas[timesteps], noise)
132137
noisy_samples = original_samples + noise * sigmas
133138

0 commit comments

Comments
 (0)