14
14
15
15
16
16
from dataclasses import dataclass
17
- from typing import Tuple , Union
17
+ from typing import Optional , Tuple , Union
18
18
19
19
import numpy as np
20
20
import torch
@@ -72,13 +72,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
72
72
@register_to_config
73
73
def __init__ (
74
74
self ,
75
- sigma_min = 0.02 ,
76
- sigma_max = 100 ,
77
- s_noise = 1.007 ,
78
- s_churn = 80 ,
79
- s_min = 0.05 ,
80
- s_max = 50 ,
81
- tensor_format = "pt" ,
75
+ sigma_min : float = 0.02 ,
76
+ sigma_max : float = 100 ,
77
+ s_noise : float = 1.007 ,
78
+ s_churn : float = 80 ,
79
+ s_min : float = 0.05 ,
80
+ s_max : float = 50 ,
81
+ tensor_format : str = "pt" ,
82
82
):
83
83
# setable values
84
84
self .num_inference_steps = None
@@ -88,7 +88,7 @@ def __init__(
88
88
self .tensor_format = tensor_format
89
89
self .set_format (tensor_format = tensor_format )
90
90
91
- def set_timesteps (self , num_inference_steps ):
91
+ def set_timesteps (self , num_inference_steps : int ):
92
92
self .num_inference_steps = num_inference_steps
93
93
self .timesteps = np .arange (0 , self .num_inference_steps )[::- 1 ].copy ()
94
94
self .schedule = [
@@ -99,7 +99,9 @@ def set_timesteps(self, num_inference_steps):
99
99
100
100
self .set_format (tensor_format = self .tensor_format )
101
101
102
- def add_noise_to_input (self , sample , sigma , generator = None ):
102
+ def add_noise_to_input (
103
+ self , sample : Union [torch .FloatTensor , np .ndarray ], sigma : float , generator : Optional [torch .Generator ] = None
104
+ ) -> Tuple [Union [torch .FloatTensor , np .ndarray ], float ]:
103
105
"""
104
106
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
105
107
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
0 commit comments