 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 from torch import nn, einsum
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.special import expm1
 import torchvision.transforms as T
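Note on the import swap: torch.cuda.amp.autocast is deprecated in recent PyTorch in favour of torch.amp.autocast, which takes the device type as an explicit first argument. A minimal sketch of the updated call pattern used at the call sites further down:

import torch
from torch.amp import autocast

with autocast('cuda', enabled = False):
    # the wrapped block runs in full precision, e.g. the text encoding below
    ...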
@@ -187,6 +187,15 @@ def safe_get_tuple_index(tup, index, default = None):
         return default
     return tup[index]
 
+def pack_one_with_inverse(x, pattern):
+    packed, packed_shape = pack([x], pattern)
+
+    def inverse(x, inverse_pattern = None):
+        inverse_pattern = default(inverse_pattern, pattern)
+        return unpack(x, packed_shape, inverse_pattern)[0]
+
+    return packed, inverse
+
 # image normalization functions
 # ddpms expect images to be in the range of -1 to 1
 
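A minimal usage sketch for the new helper, assuming einops' pack and unpack are in scope as elsewhere in this file: it flattens everything after the batch dimension and returns a closure that restores the original shape.

import torch

x = torch.randn(2, 3, 64, 64)

flat, inverse = pack_one_with_inverse(x, 'b *')   # flat has shape (2, 3 * 64 * 64)
restored = inverse(flat)                          # back to (2, 3, 64, 64)

assert restored.shape == x.shape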
@@ -206,6 +215,21 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
+# for improved cfg, getting parallel and orthogonal components of cfg update
+
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 # gaussian diffusion with continuous time helper functions and classes
 # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
 
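A quick sanity check of the decomposition, sketched with hypothetical tensors rather than real model outputs: the two components reassemble the input, and the orthogonal part has numerically zero inner product with the reference tensor y.

import torch

logits      = torch.randn(2, 3, 8, 8)   # stand-in for the conditioned prediction
null_logits = torch.randn(2, 3, 8, 8)   # stand-in for the unconditioned prediction
update = logits - null_logits

parallel, orthogonal = project(update, logits)

# the two components add back up to the original update
assert torch.allclose(parallel + orthogonal, update, atol = 1e-5)

# per-sample inner product of the orthogonal component with logits is ~0
dots = (orthogonal * logits).flatten(1).sum(dim = -1)
assert torch.allclose(dots, torch.zeros_like(dots), atol = 1e-4)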
@@ -1511,6 +1535,8 @@ def forward_with_cond_scale(
         self,
         *args,
         cond_scale = 1.,
+        remove_parallel_component = True,
+        keep_parallel_frac = 0.,
         **kwargs
     ):
         logits = self.forward(*args, **kwargs)
@@ -1519,7 +1545,14 @@ def forward_with_cond_scale(
             return logits
 
         null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
-        return null_logits + (logits - null_logits) * cond_scale
+
+        update = (logits - null_logits)
+
+        if remove_parallel_component:
+            parallel, orthogonal = project(update, logits)
+            update = orthogonal + parallel * keep_parallel_frac
+
+        return logits + update * (cond_scale - 1)
 
     def forward(
         self,
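When remove_parallel_component is False, the rewritten return is algebraically the classic CFG formula, since null + (logits - null) * s = logits + (logits - null) * (s - 1). A small check with plain tensors, independent of any Imagen internals:

import torch

logits, null_logits, cond_scale = torch.randn(4), torch.randn(4), 5.

classic   = null_logits + (logits - null_logits) * cond_scale
rewritten = logits + (logits - null_logits) * (cond_scale - 1)

assert torch.allclose(classic, rewritten, atol = 1e-5)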
@@ -2055,6 +2088,8 @@ def p_mean_variance(
         self_cond = None,
         lowres_noise_times = None,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         model_output = None,
         t_next = None,
         pred_objective = 'noise',
@@ -2076,6 +2111,8 @@ def p_mean_variance(
             text_mask = text_mask,
             cond_images = cond_images,
             cond_scale = cond_scale,
+            remove_parallel_component = cfg_remove_parallel_component,
+            keep_parallel_frac = cfg_keep_parallel_frac,
             lowres_cond_img = lowres_cond_img,
             self_cond = self_cond,
             lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
@@ -2124,6 +2161,8 @@ def p_sample(
         cond_video_frames = None,
         post_cond_video_frames = None,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         self_cond = None,
         lowres_cond_img = None,
         lowres_noise_times = None,
@@ -2149,6 +2188,8 @@ def p_sample(
             text_mask = text_mask,
             cond_images = cond_images,
             cond_scale = cond_scale,
+            cfg_remove_parallel_component = cfg_remove_parallel_component,
+            cfg_keep_parallel_frac = cfg_keep_parallel_frac,
             lowres_cond_img = lowres_cond_img,
             self_cond = self_cond,
             lowres_noise_times = lowres_noise_times,
@@ -2185,6 +2226,8 @@ def p_sample_loop(
         init_images = None,
         skip_steps = None,
         cond_scale = 1,
+        cfg_remove_parallel_component = False,
+        cfg_keep_parallel_frac = 0.,
         pred_objective = 'noise',
         dynamic_threshold = True,
         use_tqdm = True
@@ -2260,6 +2303,8 @@ def p_sample_loop(
                 text_mask = text_mask,
                 cond_images = cond_images,
                 cond_scale = cond_scale,
+                cfg_remove_parallel_component = cfg_remove_parallel_component,
+                cfg_keep_parallel_frac = cfg_keep_parallel_frac,
                 self_cond = self_cond,
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_times = lowres_noise_times,
@@ -2308,6 +2353,8 @@ def sample(
         skip_steps = None,
         batch_size = 1,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         lowres_sample_noise_level = None,
         start_at_unet_number = 1,
         start_image_or_video = None,
@@ -2326,7 +2373,7 @@ def sample(
         if exists(texts) and not exists(text_embeds) and not self.unconditional:
             assert all([*map(len, texts)]), 'text cannot be empty'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -2469,6 +2516,8 @@ def sample(
                 init_images = unet_init_images,
                 skip_steps = unet_skip_steps,
                 cond_scale = unet_cond_scale,
+                cfg_remove_parallel_component = cfg_remove_parallel_component,
+                cfg_keep_parallel_frac = cfg_keep_parallel_frac,
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_times = lowres_noise_times,
                 noise_scheduler = noise_scheduler,
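A hypothetical end-to-end call showing where the new flags surface for users, assuming an already trained imagen instance (an Imagen object); the default here enables the parallel-component removal, and cfg_keep_parallel_frac blends a fraction of the removed component back in.

images = imagen.sample(
    texts = ['a photo of a corgi wearing a party hat'],
    cond_scale = 5.,
    cfg_remove_parallel_component = True,  # drop the part of the CFG update parallel to the conditioned prediction
    cfg_keep_parallel_frac = 0.            # or e.g. 0.25 to keep a quarter of it
)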
@@ -2695,7 +2744,7 @@ def forward(
             assert all([*map(len, texts)]), 'text cannot be empty'
             assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))