1
1
from math import sqrt
2
+ from random import random
2
3
import torch
3
4
from torch import nn , einsum
4
5
import torch .nn .functional as F
@@ -52,7 +53,7 @@ def __init__(
52
53
):
53
54
super ().__init__ ()
54
55
assert net .learned_sinusoidal_cond
55
- assert not net .self_condition , 'not supported yet'
56
+ self . self_condition = net .self_condition
56
57
57
58
self .net = net
58
59
@@ -100,7 +101,7 @@ def c_noise(self, sigma):
100
101
# preconditioned network output
101
102
# equation (7) in the paper
102
103
103
- def preconditioned_network_forward (self , noised_images , sigma , clamp = False ):
104
+ def preconditioned_network_forward (self , noised_images , sigma , self_cond = None , clamp = False ):
104
105
batch , device = noised_images .shape [0 ], noised_images .device
105
106
106
107
if isinstance (sigma , float ):
@@ -110,7 +111,8 @@ def preconditioned_network_forward(self, noised_images, sigma, clamp = False):
110
111
111
112
net_out = self .net (
112
113
self .c_in (padded_sigma ) * noised_images ,
113
- self .c_noise (sigma )
114
+ self .c_noise (sigma ),
115
+ self_cond
114
116
)
115
117
116
118
out = self .c_skip (padded_sigma ) * noised_images + self .c_out (padded_sigma ) * net_out
@@ -161,6 +163,10 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
161
163
162
164
images = init_sigma * torch .randn (shape , device = self .device )
163
165
166
+ # for self conditioning
167
+
168
+ x_start = None
169
+
164
170
# gradually denoise
165
171
166
172
for sigma , sigma_next , gamma in tqdm (sigmas_and_gammas , desc = 'sampling time step' ):
@@ -171,19 +177,24 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
171
177
sigma_hat = sigma + gamma * sigma
172
178
images_hat = images + sqrt (sigma_hat ** 2 - sigma ** 2 ) * eps
173
179
174
- model_output = self .preconditioned_network_forward (images_hat , sigma_hat , clamp = clamp )
180
+ self_cond = x_start if self .self_condition else None
181
+
182
+ model_output = self .preconditioned_network_forward (images_hat , sigma_hat , self_cond , clamp = clamp )
175
183
denoised_over_sigma = (images_hat - model_output ) / sigma_hat
176
184
177
185
images_next = images_hat + (sigma_next - sigma_hat ) * denoised_over_sigma
178
186
179
187
# second-order correction, if not the last timestep
180
188
181
189
if sigma_next != 0 :
182
- model_output_next = self .preconditioned_network_forward (images_next , sigma_next , clamp = clamp )
190
+ self_cond = model_output if self .self_condition else None
191
+
192
+ model_output_next = self .preconditioned_network_forward (images_next , sigma_next , self_cond , clamp = clamp )
183
193
denoised_prime_over_sigma = (images_next - model_output_next ) / sigma_next
184
194
images_next = images_hat + 0.5 * (sigma_next - sigma_hat ) * (denoised_over_sigma + denoised_prime_over_sigma )
185
195
186
196
images = images_next
197
+ x_start = model_output
187
198
188
199
images = images .clamp (- 1. , 1. )
189
200
return unnormalize_to_zero_to_one (images )
@@ -211,7 +222,15 @@ def forward(self, images):
211
222
212
223
noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
213
224
214
- denoised = self .preconditioned_network_forward (noised_images , sigmas )
225
+ self_cond = None
226
+
227
+ if self .self_condition and random () < 0.5 :
228
+ # from Hinton's group's Bit Diffusion paper
229
+ with torch .no_grad ():
230
+ self_cond = self .preconditioned_network_forward (noised_images , sigmas )
231
+ self_cond .detach_ ()
232
+
233
+ denoised = self .preconditioned_network_forward (noised_images , sigmas , self_cond )
215
234
216
235
losses = F .mse_loss (denoised , images , reduction = 'none' )
217
236
losses = reduce (losses , 'b ... -> b' , 'mean' )
0 commit comments