Skip to content

Commit 192f8b9

Browse files
committed
fix a few autocast warnings, add new technique for cfg
1 parent c166739 commit 192f8b9

File tree

5 files changed

+102
-9
lines changed

5 files changed

+102
-9
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
947947
note = {under review}
948948
}
949949
```
950+
951+
```bibtex
952+
@inproceedings{Sadat2024EliminatingOA,
953+
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
954+
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
955+
year = {2024},
956+
url = {https://api.semanticscholar.org/CorpusID:273098845}
957+
}
958+
```

imagen_pytorch/elucidated_imagen.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn.functional as F
1111
from torch import nn
12-
from torch.cuda.amp import autocast
12+
from torch.amp import autocast
1313
from torch.nn.parallel import DistributedDataParallel
1414
import torchvision.transforms as T
1515

@@ -565,6 +565,8 @@ def sample(
565565
video_frames = None,
566566
batch_size = 1,
567567
cond_scale = 1.,
568+
cfg_remove_parallel_component = True,
569+
cfg_keep_parallel_frac = 0.,
568570
lowres_sample_noise_level = None,
569571
start_at_unet_number = 1,
570572
start_image_or_video = None,
@@ -583,7 +585,7 @@ def sample(
583585
if exists(texts) and not exists(text_embeds) and not self.unconditional:
584586
assert all([*map(len, texts)]), 'text cannot be empty'
585587

586-
with autocast(enabled = False):
588+
with autocast('cuda', enabled = False):
587589
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
588590

589591
text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -724,6 +726,8 @@ def sample(
724726
sigma_min = unet_sigma_min,
725727
sigma_max = unet_sigma_max,
726728
cond_scale = unet_cond_scale,
729+
remove_parallel_component = cfg_remove_parallel_component,
730+
keep_parallel_frac = cfg_keep_parallel_frac,
727731
lowres_cond_img = lowres_cond_img,
728732
lowres_noise_times = lowres_noise_times,
729733
dynamic_threshold = dynamic_threshold,
@@ -811,7 +815,7 @@ def forward(
811815
assert all([*map(len, texts)]), 'text cannot be empty'
812816
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
813817

814-
with autocast(enabled = False):
818+
with autocast('cuda', enabled = False):
815819
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
816820

817821
text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

imagen_pytorch/imagen_pytorch.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch.nn.functional as F
1212
from torch.nn.parallel import DistributedDataParallel
1313
from torch import nn, einsum
14-
from torch.cuda.amp import autocast
14+
from torch.amp import autocast
1515
from torch.special import expm1
1616
import torchvision.transforms as T
1717

@@ -187,6 +187,15 @@ def safe_get_tuple_index(tup, index, default = None):
187187
return default
188188
return tup[index]
189189

190+
def pack_one_with_inverse(x, pattern):
191+
packed, packed_shape = pack([x], pattern)
192+
193+
def inverse(x, inverse_pattern = None):
194+
inverse_pattern = default(inverse_pattern, pattern)
195+
return unpack(x, packed_shape, inverse_pattern)[0]
196+
197+
return packed, inverse
198+
190199
# image normalization functions
191200
# ddpms expect images to be in the range of -1 to 1
192201

@@ -206,6 +215,21 @@ def prob_mask_like(shape, prob, device):
206215
else:
207216
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
208217

218+
# for improved cfg, getting parallel and orthogonal components of cfg update
219+
220+
def project(x, y):
221+
x, inverse = pack_one_with_inverse(x, 'b *')
222+
y, _ = pack_one_with_inverse(y, 'b *')
223+
224+
dtype = x.dtype
225+
x, y = x.double(), y.double()
226+
unit = F.normalize(y, dim = -1)
227+
228+
parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
229+
orthogonal = x - parallel
230+
231+
return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
232+
209233
# gaussian diffusion with continuous time helper functions and classes
210234
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
211235

@@ -1511,6 +1535,8 @@ def forward_with_cond_scale(
15111535
self,
15121536
*args,
15131537
cond_scale = 1.,
1538+
remove_parallel_component = True,
1539+
keep_parallel_frac = 0.,
15141540
**kwargs
15151541
):
15161542
logits = self.forward(*args, **kwargs)
@@ -1519,7 +1545,14 @@ def forward_with_cond_scale(
15191545
return logits
15201546

15211547
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1522-
return null_logits + (logits - null_logits) * cond_scale
1548+
1549+
update = (logits - null_logits)
1550+
1551+
if remove_parallel_component:
1552+
parallel, orthogonal = project(update, logits)
1553+
update = orthogonal + parallel * keep_parallel_frac
1554+
1555+
return logits + update * (cond_scale - 1)
15231556

15241557
def forward(
15251558
self,
@@ -2055,6 +2088,8 @@ def p_mean_variance(
20552088
self_cond = None,
20562089
lowres_noise_times = None,
20572090
cond_scale = 1.,
2091+
cfg_remove_parallel_component = True,
2092+
cfg_keep_parallel_frac = 0.,
20582093
model_output = None,
20592094
t_next = None,
20602095
pred_objective = 'noise',
@@ -2076,6 +2111,8 @@ def p_mean_variance(
20762111
text_mask = text_mask,
20772112
cond_images = cond_images,
20782113
cond_scale = cond_scale,
2114+
remove_parallel_component = cfg_remove_parallel_component,
2115+
keep_parallel_frac = cfg_keep_parallel_frac,
20792116
lowres_cond_img = lowres_cond_img,
20802117
self_cond = self_cond,
20812118
lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
@@ -2124,6 +2161,8 @@ def p_sample(
21242161
cond_video_frames = None,
21252162
post_cond_video_frames = None,
21262163
cond_scale = 1.,
2164+
cfg_remove_parallel_component = True,
2165+
cfg_keep_parallel_frac = 0.,
21272166
self_cond = None,
21282167
lowres_cond_img = None,
21292168
lowres_noise_times = None,
@@ -2149,6 +2188,8 @@ def p_sample(
21492188
text_mask = text_mask,
21502189
cond_images = cond_images,
21512190
cond_scale = cond_scale,
2191+
cfg_remove_parallel_component = cfg_remove_parallel_component,
2192+
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
21522193
lowres_cond_img = lowres_cond_img,
21532194
self_cond = self_cond,
21542195
lowres_noise_times = lowres_noise_times,
@@ -2185,6 +2226,8 @@ def p_sample_loop(
21852226
init_images = None,
21862227
skip_steps = None,
21872228
cond_scale = 1,
2229+
cfg_remove_parallel_component = False,
2230+
cfg_keep_parallel_frac = 0.,
21882231
pred_objective = 'noise',
21892232
dynamic_threshold = True,
21902233
use_tqdm = True
@@ -2260,6 +2303,8 @@ def p_sample_loop(
22602303
text_mask = text_mask,
22612304
cond_images = cond_images,
22622305
cond_scale = cond_scale,
2306+
cfg_remove_parallel_component = cfg_remove_parallel_component,
2307+
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
22632308
self_cond = self_cond,
22642309
lowres_cond_img = lowres_cond_img,
22652310
lowres_noise_times = lowres_noise_times,
@@ -2308,6 +2353,8 @@ def sample(
23082353
skip_steps = None,
23092354
batch_size = 1,
23102355
cond_scale = 1.,
2356+
cfg_remove_parallel_component = True,
2357+
cfg_keep_parallel_frac = 0.,
23112358
lowres_sample_noise_level = None,
23122359
start_at_unet_number = 1,
23132360
start_image_or_video = None,
@@ -2326,7 +2373,7 @@ def sample(
23262373
if exists(texts) and not exists(text_embeds) and not self.unconditional:
23272374
assert all([*map(len, texts)]), 'text cannot be empty'
23282375

2329-
with autocast(enabled = False):
2376+
with autocast('cuda', enabled = False):
23302377
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
23312378

23322379
text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -2469,6 +2516,8 @@ def sample(
24692516
init_images = unet_init_images,
24702517
skip_steps = unet_skip_steps,
24712518
cond_scale = unet_cond_scale,
2519+
cfg_remove_parallel_component = cfg_remove_parallel_component,
2520+
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
24722521
lowres_cond_img = lowres_cond_img,
24732522
lowres_noise_times = lowres_noise_times,
24742523
noise_scheduler = noise_scheduler,
@@ -2695,7 +2744,7 @@ def forward(
26952744
assert all([*map(len, texts)]), 'text cannot be empty'
26962745
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
26972746

2698-
with autocast(enabled = False):
2747+
with autocast('cuda', enabled = False):
26992748
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
27002749

27012750
text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

imagen_pytorch/imagen_video.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ def pad_tuple_to_length(t, length, fillvalue = None):
9595
return t
9696
return (*t, *((fillvalue,) * remain_length))
9797

98+
def pack_one_with_inverse(x, pattern):
99+
packed, packed_shape = pack([x], pattern)
100+
101+
def inverse(x, inverse_pattern = None):
102+
inverse_pattern = default(inverse_pattern, pattern)
103+
return unpack(x, packed_shape, inverse_pattern)[0]
104+
105+
return packed, inverse
106+
98107
# helper classes
99108

100109
class Identity(nn.Module):
@@ -131,6 +140,19 @@ def masked_mean(t, *, dim, mask = None):
131140

132141
return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
133142

143+
def project(x, y):
144+
x, inverse = pack_one_with_inverse(x, 'b *')
145+
y, _ = pack_one_with_inverse(y, 'b *')
146+
147+
dtype = x.dtype
148+
x, y = x.double(), y.double()
149+
unit = F.normalize(y, dim = -1)
150+
151+
parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
152+
orthogonal = x - parallel
153+
154+
return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
155+
134156
def resize_video_to(
135157
video,
136158
target_image_size,
@@ -1637,6 +1659,8 @@ def forward_with_cond_scale(
16371659
self,
16381660
*args,
16391661
cond_scale = 1.,
1662+
remove_parallel_component = False,
1663+
keep_parallel_frac = 0.,
16401664
**kwargs
16411665
):
16421666
logits = self.forward(*args, **kwargs)
@@ -1645,7 +1669,14 @@ def forward_with_cond_scale(
16451669
return logits
16461670

16471671
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1648-
return null_logits + (logits - null_logits) * cond_scale
1672+
1673+
update = (logits - null_logits)
1674+
1675+
if remove_parallel_component:
1676+
parallel, orthogonal = project(update, logits)
1677+
update = orthogonal + parallel * keep_parallel_frac
1678+
1679+
return logits + update * (cond_scale - 1)
16491680

16501681
def forward(
16511682
self,

imagen_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.0.0'
1+
__version__ = '2.1.0'

0 commit comments

Comments
 (0)