@@ -49,6 +49,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
     https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
     differential equations." https://arxiv.org/abs/2011.13456
+
+    For more details on the parameters, see Appendix E of the original paper: "Elucidating the Design Space of
+    Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+    optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+    Args:
+        sigma_min (`float`): minimum noise magnitude
+        sigma_max (`float`): maximum noise magnitude
+        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+            A reasonable range is [1.000, 1.011].
+        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+            A reasonable range is [0, 100].
+        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+            A reasonable range is [0, 10].
+        s_max (`float`): the end value of the sigma range where we add noise.
+            A reasonable range is [0.2, 80].
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
     """
 
     @register_to_config
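For readers of this PR, a quick sketch of how the documented arguments are passed in practice. The values below are illustrative choices (s_max=50 matches the default visible in this diff; the others are plausible settings, not tuned ones):

from diffusers import KarrasVeScheduler

# Illustrative instantiation; see Table 5 of Karras et al. for per-model tuning.
scheduler = KarrasVeScheduler(
    sigma_min=0.02,      # minimum noise magnitude
    sigma_max=100,       # maximum noise magnitude
    s_noise=1.007,       # extra noise factor, typically in [1.000, 1.011]
    s_churn=80,          # overall stochasticity, typically in [0, 100]
    s_min=0.05,          # start of the sigma range where noise is added
    s_max=50,            # end of the sigma range where noise is added
    tensor_format="pt",  # return torch tensors rather than numpy arrays
)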
@@ -62,23 +80,6 @@ def __init__(
         s_max: float = 50,
         tensor_format: str = "pt",
     ):
-        """
-        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
-        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
-        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
-
-        Args:
-            sigma_min (`float`): minimum noise magnitude
-            sigma_max (`float`): maximum noise magnitude
-            s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
-                A reasonable range is [1.000, 1.011].
-            s_churn (`float`): the parameter controlling the overall amount of stochasticity.
-                A reasonable range is [0, 100].
-            s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
-                A reasonable range is [0, 10].
-            s_max (`float`): the end value of the sigma range where we add noise.
-                A reasonable range is [0.2, 80].
-        """
         # setable values
         self.num_inference_steps = None
         self.timesteps = None
@@ -88,6 +89,14 @@ def __init__(
         self.set_format(tensor_format=tensor_format)
 
     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
         self.schedule = [
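The hunk is cut off mid-list by the diff context, so the exact schedule expression is not shown here. As a minimal standalone sketch, assuming the schedule interpolates geometrically between sigma_max and sigma_min over the inference steps (an assumption, not the verbatim method body):

import numpy as np

# Hypothetical sketch of a Karras-style geometric sigma schedule; the exact
# expression used by the scheduler is truncated out of this hunk.
def karras_sigma_schedule(sigma_min: float, sigma_max: float, num_inference_steps: int) -> np.ndarray:
    i = np.arange(num_inference_steps)
    # Geometric interpolation: sigma_max at the first step down to sigma_min at the last.
    return sigma_max * (sigma_min / sigma_max) ** (i / (num_inference_steps - 1))

print(karras_sigma_schedule(0.02, 100.0, 5))  # 100.0 decaying geometrically to 0.02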
@@ -104,6 +113,8 @@ def add_noise_to_input(
         """
         Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
         higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        Args:
+            sample (`torch.FloatTensor` or `np.ndarray`): the current sample to be noised.
+            sigma (`float`): the current noise level sigma_i.
+            generator (`torch.Generator`, *optional*): random number generator for the churn noise.
         """
         if self.s_min <= sigma <= self.s_max:
             gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
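The rest of the method body is truncated by the diff. As a self-contained sketch of the churn step the docstring describes (following Algorithm 2 of Karras et al.; this is an assumed shape of the logic, not the verbatim body):

import torch

# Standalone sketch of the "churn" step: raise the noise level from sigma to
# sigma_hat = sigma + gamma * sigma by adding just enough fresh noise.
def churn(sample: torch.Tensor, sigma: float, gamma: float, s_noise: float = 1.007):
    eps = s_noise * torch.randn_like(sample)  # extra noise, scaled by s_noise
    sigma_hat = sigma + gamma * sigma         # raised noise level
    # Variances add, so the gap between levels is sqrt(sigma_hat^2 - sigma^2).
    sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps
    return sample_hat, sigma_hat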
@@ -125,6 +136,21 @@ def step(
         sample_hat: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): the increased noise level sigma_hat = sigma_i + gamma_i*sigma_i, as returned by
+                `add_noise_to_input`.
+            sigma_prev (`float`): the noise level sigma_{i+1} of the next step in the schedule.
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): the noised sample at level sigma_hat, as returned by
+                `add_noise_to_input`.
+            return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+        Returns:
+            KarrasVeOutput: updated sample in the diffusion chain and the derivative used by `step_correct`.
+
+        """
 
         pred_original_sample = sample_hat + sigma_hat * model_output
         derivative = (sample_hat - pred_original_sample) / sigma_hat
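The remainder of `step` is truncated by the diff; the update it performs is the first-order (Euler) move along the probability-flow ODE, which `step_correct` later refines (the same expression appears in the last visible line of `step_correct` below). A minimal sketch:

# First-order (Euler) update of Algorithm 2 in Karras et al.:
# move from noise level sigma_hat down to sigma_prev along the derivative.
def euler_step(sample_hat, derivative, sigma_hat: float, sigma_prev: float):
    return sample_hat + (sigma_prev - sigma_hat) * derivative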
@@ -145,7 +171,22 @@ def step_correct(
         derivative: Union[torch.FloatTensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
+        """
+        Correct the predicted sample based on the model_output of the network. This is the second-order (Heun)
+        correction of Algorithm 2 in [1]: the derivative is re-evaluated at the predicted sample and averaged with
+        the first-order derivative.
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model,
+                evaluated at (sample_prev, sigma_prev).
+            sigma_hat (`float`): the increased noise level the first-order step started from.
+            sigma_prev (`float`): the noise level sigma_{i+1} targeted by the step.
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): the noised sample the first-order step started from.
+            sample_prev (`torch.FloatTensor` or `np.ndarray`): the sample produced by the first-order `step`.
+            derivative (`torch.FloatTensor` or `np.ndarray`): the derivative returned by `step`.
+            return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+        Returns:
+            KarrasVeOutput: updated sample in the diffusion chain and its derivative.
+
+        """
         pred_original_sample = sample_prev + sigma_prev * model_output
         derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
         sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
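Putting the three methods together, a rough end-to-end sampling loop over this API could look like the sketch below. `model` is a placeholder for whatever wrapper produces `model_output` in the convention `step` expects, and the attribute and keyword names are assumed from the surrounding diff rather than guaranteed:

import torch

# Hypothetical sampling loop over the scheduler API documented above.
@torch.no_grad()
def sample(model, scheduler, shape, num_inference_steps=50, generator=None):
    scheduler.set_timesteps(num_inference_steps)
    # Start from pure noise at the maximum noise level.
    x = torch.randn(shape, generator=generator) * scheduler.sigma_max
    for t in scheduler.timesteps:
        sigma = scheduler.schedule[t]
        sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0.0
        # Churn: raise the noise level from sigma to sigma_hat.
        x_hat, sigma_hat = scheduler.add_noise_to_input(x, sigma, generator=generator)
        # First-order (Euler) step from sigma_hat down to sigma_prev.
        out = scheduler.step(model(x_hat, sigma_hat), sigma_hat, sigma_prev, x_hat)
        x = out.prev_sample
        # Second-order correction, skipped on the final step where sigma_prev == 0.
        if sigma_prev != 0:
            out = scheduler.step_correct(
                model(x, sigma_prev), sigma_hat, sigma_prev, x_hat, x, out.derivative
            )
            x = out.prev_sample
    return x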