@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        num_train_timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
-        beta_schedule="linear",
-        trained_betas=None,
-        timestep_values=None,
-        tensor_format="pt",
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        timestep_values: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
     ):
         """
         Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
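
For context (not part of the commit): a minimal construction sketch under the newly annotated signature, assuming `diffusers` exports `LMSDiscreteScheduler` as it does in this repo. The argument values are simply the annotated defaults.

```python
from diffusers import LMSDiscreteScheduler

# All arguments mirror the annotated defaults above; trained_betas may instead
# be a precomputed np.ndarray of betas, per Optional[np.ndarray].
scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    trained_betas=None,
    timestep_values=None,
    tensor_format="pt",  # "pt" keeps internal state as torch tensors
)
```
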
@@ -79,7 +79,7 @@ def lms_derivative(tau):
 
         return integrated_coeff
 
-    def set_timesteps(self, num_inference_steps):
+    def set_timesteps(self, num_inference_steps: int):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
 
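
Again for context, not part of the diff: a quick sketch of the annotated call, under the default constructor arguments.

```python
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(tensor_format="pt")
scheduler.set_timesteps(50)  # num_inference_steps: int

# timesteps run from num_train_timesteps - 1 down to 0 in 50 evenly spaced values
print(scheduler.timesteps[:3])  # approximately [999.0, 978.61, 958.22]
```
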
@@ -127,7 +127,12 @@ def step(
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def add_noise(self, original_samples, noise, timesteps):
+    def add_noise(
+        self,
+        original_samples: Union[torch.FloatTensor, np.ndarray],
+        noise: Union[torch.FloatTensor, np.ndarray],
+        timesteps: Union[torch.IntTensor, np.ndarray],
+    ) -> Union[torch.FloatTensor, np.ndarray]:
         sigmas = self.match_shape(self.sigmas[timesteps], noise)
         noisy_samples = original_samples + noise * sigmas
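
One last sketch for context: exercising the widened `add_noise` signature with torch tensors; per the `Union` annotations, NumPy arrays should work symmetrically.

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(tensor_format="pt")
scheduler.set_timesteps(50)

sample = torch.randn(2, 3, 8, 8)   # clean samples
noise = torch.randn_like(sample)   # Gaussian noise to mix in
timesteps = torch.tensor([0, 10])  # one sigma index per sample
noisy = scheduler.add_noise(sample, noise, timesteps)
# noisy == sample + noise * sigma, with sigma broadcast via match_shape
```
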