Skip to content

Commit beb2f2d

Browse files
committed
Add self-conditioning support for elucidated DDPM
1 parent f0d59ac commit beb2f2d

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

denoising_diffusion_pytorch/elucidated_diffusion.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from math import sqrt
2+
from random import random
23
import torch
34
from torch import nn, einsum
45
import torch.nn.functional as F
@@ -52,7 +53,7 @@ def __init__(
5253
):
5354
super().__init__()
5455
assert net.learned_sinusoidal_cond
55-
assert not net.self_condition, 'not supported yet'
56+
self.self_condition = net.self_condition
5657

5758
self.net = net
5859

@@ -100,7 +101,7 @@ def c_noise(self, sigma):
100101
# preconditioned network output
101102
# equation (7) in the paper
102103

103-
def preconditioned_network_forward(self, noised_images, sigma, clamp = False):
104+
def preconditioned_network_forward(self, noised_images, sigma, self_cond = None, clamp = False):
104105
batch, device = noised_images.shape[0], noised_images.device
105106

106107
if isinstance(sigma, float):
@@ -110,7 +111,8 @@ def preconditioned_network_forward(self, noised_images, sigma, clamp = False):
110111

111112
net_out = self.net(
112113
self.c_in(padded_sigma) * noised_images,
113-
self.c_noise(sigma)
114+
self.c_noise(sigma),
115+
self_cond
114116
)
115117

116118
out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
@@ -161,6 +163,10 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
161163

162164
images = init_sigma * torch.randn(shape, device = self.device)
163165

166+
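# for self conditioning: holds the previous step's model output, fed back to the network as self_cond (None on the first step)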
167+
168+
x_start = None
169+
164170
# gradually denoise
165171

166172
for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
@@ -171,19 +177,24 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
171177
sigma_hat = sigma + gamma * sigma
172178
images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
173179

174-
model_output = self.preconditioned_network_forward(images_hat, sigma_hat, clamp = clamp)
180+
self_cond = x_start if self.self_condition else None
181+
182+
model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp = clamp)
175183
denoised_over_sigma = (images_hat - model_output) / sigma_hat
176184

177185
images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
178186

179187
# second order correction, if not the last timestep
180188

181189
if sigma_next != 0:
182-
model_output_next = self.preconditioned_network_forward(images_next, sigma_next, clamp = clamp)
190+
self_cond = model_output if self.self_condition else None
191+
192+
model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp = clamp)
183193
denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
184194
images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
185195

186196
images = images_next
197+
x_start = model_output
187198

188199
images = images.clamp(-1., 1.)
189200
return unnormalize_to_zero_to_one(images)
@@ -211,7 +222,15 @@ def forward(self, images):
211222

212223
noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
213224

214-
denoised = self.preconditioned_network_forward(noised_images, sigmas)
225+
self_cond = None
226+
227+
if self.self_condition and random() < 0.5:
228+
# self-conditioning technique from Hinton's group's Bit Diffusion paper
229+
with torch.no_grad():
230+
self_cond = self.preconditioned_network_forward(noised_images, sigmas)
231+
self_cond.detach_()
232+
233+
denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond)
215234

216235
losses = F.mse_loss(denoised, images, reduction = 'none')
217236
losses = reduce(losses, 'b ... -> b', 'mean')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.27.1',
6+
version = '0.27.2',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)