@@ -112,7 +112,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
112
112
self .fir_kernel = fir_kernel
113
113
self .out_channels = out_channels
114
114
115
- def _upsample_2d (self , x , weight = None , kernel = None , factor = 2 , gain = 1 ):
115
+ def _upsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
116
116
"""Fused `upsample_2d()` followed by `Conv2d()`.
117
117
118
118
Args:
@@ -151,34 +151,46 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
151
151
convW = weight .shape [3 ]
152
152
inC = weight .shape [1 ]
153
153
154
- p = (kernel .shape [0 ] - factor ) - (convW - 1 )
154
+ pad_value = (kernel .shape [0 ] - factor ) - (convW - 1 )
155
155
156
156
stride = (factor , factor )
157
157
# Determine data dimensions.
158
- output_shape = ((x .shape [2 ] - 1 ) * factor + convH , (x .shape [3 ] - 1 ) * factor + convW )
158
+ output_shape = (
159
+ (hidden_states .shape [2 ] - 1 ) * factor + convH ,
160
+ (hidden_states .shape [3 ] - 1 ) * factor + convW ,
161
+ )
159
162
output_padding = (
160
- output_shape [0 ] - (x .shape [2 ] - 1 ) * stride [0 ] - convH ,
161
- output_shape [1 ] - (x .shape [3 ] - 1 ) * stride [1 ] - convW ,
163
+ output_shape [0 ] - (hidden_states .shape [2 ] - 1 ) * stride [0 ] - convH ,
164
+ output_shape [1 ] - (hidden_states .shape [3 ] - 1 ) * stride [1 ] - convW ,
162
165
)
163
166
assert output_padding [0 ] >= 0 and output_padding [1 ] >= 0
164
167
inC = weight .shape [1 ]
165
- num_groups = x .shape [1 ] // inC
168
+ num_groups = hidden_states .shape [1 ] // inC
166
169
167
170
# Transpose weights.
168
171
weight = torch .reshape (weight , (num_groups , - 1 , inC , convH , convW ))
169
172
weight = torch .flip (weight , dims = [3 , 4 ]).permute (0 , 2 , 1 , 3 , 4 )
170
173
weight = torch .reshape (weight , (num_groups * inC , - 1 , convH , convW ))
171
174
172
- x = F .conv_transpose2d (x , weight , stride = stride , output_padding = output_padding , padding = 0 )
175
+ inverse_conv = F .conv_transpose2d (
176
+ hidden_states , weight , stride = stride , output_padding = output_padding , padding = 0
177
+ )
173
178
174
- x = upfirdn2d_native (x , torch .tensor (kernel , device = x .device ), pad = ((p + 1 ) // 2 + factor - 1 , p // 2 + 1 ))
179
+ output = upfirdn2d_native (
180
+ inverse_conv ,
181
+ torch .tensor (kernel , device = inverse_conv .device ),
182
+ pad = ((pad_value + 1 ) // 2 + factor - 1 , pad_value // 2 + 1 ),
183
+ )
175
184
else :
176
- p = kernel .shape [0 ] - factor
177
- x = upfirdn2d_native (
178
- x , torch .tensor (kernel , device = x .device ), up = factor , pad = ((p + 1 ) // 2 + factor - 1 , p // 2 )
185
+ pad_value = kernel .shape [0 ] - factor
186
+ output = upfirdn2d_native (
187
+ hidden_states ,
188
+ torch .tensor (kernel , device = hidden_states .device ),
189
+ up = factor ,
190
+ pad = ((pad_value + 1 ) // 2 + factor - 1 , pad_value // 2 ),
179
191
)
180
192
181
- return x
193
+ return output
182
194
183
195
def forward (self , hidden_states ):
184
196
if self .use_conv :
@@ -200,7 +212,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
200
212
self .use_conv = use_conv
201
213
self .out_channels = out_channels
202
214
203
- def _downsample_2d (self , x , weight = None , kernel = None , factor = 2 , gain = 1 ):
215
+ def _downsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
204
216
"""Fused `Conv2d()` followed by `downsample_2d()`.
205
217
206
218
Args:
@@ -232,20 +244,29 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
232
244
233
245
if self .use_conv :
234
246
_ , _ , convH , convW = weight .shape
235
- p = (kernel .shape [0 ] - factor ) + (convW - 1 )
236
- s = [factor , factor ]
237
- x = upfirdn2d_native (x , torch .tensor (kernel , device = x .device ), pad = ((p + 1 ) // 2 , p // 2 ))
238
- x = F .conv2d (x , weight , stride = s , padding = 0 )
247
+ pad_value = (kernel .shape [0 ] - factor ) + (convW - 1 )
248
+ stride_value = [factor , factor ]
249
+ upfirdn_input = upfirdn2d_native (
250
+ hidden_states ,
251
+ torch .tensor (kernel , device = hidden_states .device ),
252
+ pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
253
+ )
254
+ hidden_states = F .conv2d (upfirdn_input , weight , stride = stride_value , padding = 0 )
239
255
else :
240
- p = kernel .shape [0 ] - factor
241
- x = upfirdn2d_native (x , torch .tensor (kernel , device = x .device ), down = factor , pad = ((p + 1 ) // 2 , p // 2 ))
256
+ pad_value = kernel .shape [0 ] - factor
257
+ hidden_states = upfirdn2d_native (
258
+ hidden_states ,
259
+ torch .tensor (kernel , device = hidden_states .device ),
260
+ down = factor ,
261
+ pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
262
+ )
242
263
243
- return x
264
+ return hidden_states
244
265
245
266
def forward (self , hidden_states ):
246
267
if self .use_conv :
247
- hidden_states = self ._downsample_2d (hidden_states , weight = self .Conv2d_0 .weight , kernel = self .fir_kernel )
248
- hidden_states = hidden_states + self .Conv2d_0 .bias .reshape (1 , - 1 , 1 , 1 )
268
+ downsample_input = self ._downsample_2d (hidden_states , weight = self .Conv2d_0 .weight , kernel = self .fir_kernel )
269
+ hidden_states = downsample_input + self .Conv2d_0 .bias .reshape (1 , - 1 , 1 , 1 )
249
270
else :
250
271
hidden_states = self ._downsample_2d (hidden_states , kernel = self .fir_kernel , factor = 2 )
251
272
@@ -332,17 +353,17 @@ def __init__(
332
353
if self .use_in_shortcut :
333
354
self .conv_shortcut = torch .nn .Conv2d (in_channels , out_channels , kernel_size = 1 , stride = 1 , padding = 0 )
334
355
335
- def forward (self , x , temb ):
336
- hidden_states = x
356
+ def forward (self , input_tensor , temb ):
357
+ hidden_states = input_tensor
337
358
338
359
hidden_states = self .norm1 (hidden_states )
339
360
hidden_states = self .nonlinearity (hidden_states )
340
361
341
362
if self .upsample is not None :
342
- x = self .upsample (x )
363
+ input_tensor = self .upsample (input_tensor )
343
364
hidden_states = self .upsample (hidden_states )
344
365
elif self .downsample is not None :
345
- x = self .downsample (x )
366
+ input_tensor = self .downsample (input_tensor )
346
367
hidden_states = self .downsample (hidden_states )
347
368
348
369
hidden_states = self .conv1 (hidden_states )
@@ -358,19 +379,19 @@ def forward(self, x, temb):
358
379
hidden_states = self .conv2 (hidden_states )
359
380
360
381
if self .conv_shortcut is not None :
361
- x = self .conv_shortcut (x )
382
+ input_tensor = self .conv_shortcut (input_tensor )
362
383
363
- out = (x + hidden_states ) / self .output_scale_factor
384
+ output_tensor = (input_tensor + hidden_states ) / self .output_scale_factor
364
385
365
- return out
386
+ return output_tensor
366
387
367
388
368
389
class Mish(torch.nn.Module):
    """Mish activation module: ``x * tanh(softplus(x))``, applied element-wise."""

    def forward(self, hidden_states):
        """Apply the Mish activation to ``hidden_states``.

        Args:
            hidden_states: Input tensor; every element is transformed
                independently.

        Returns:
            A tensor with the same shape as ``hidden_states`` holding
            ``hidden_states * tanh(softplus(hidden_states))``.
        """
        # softplus(x) = log(1 + exp(x)); the tanh of it gates the input smoothly.
        gate = torch.tanh(torch.nn.functional.softplus(hidden_states))
        return hidden_states * gate
371
392
372
393
373
- def upsample_2d (x , kernel = None , factor = 2 , gain = 1 ):
394
+ def upsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
374
395
r"""Upsample2D a batch of 2D images with the given filter.
375
396
376
397
Args:
@@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
397
418
kernel /= torch .sum (kernel )
398
419
399
420
kernel = kernel * (gain * (factor ** 2 ))
400
- p = kernel .shape [0 ] - factor
401
- return upfirdn2d_native (x , kernel .to (device = x .device ), up = factor , pad = ((p + 1 ) // 2 + factor - 1 , p // 2 ))
421
+ pad_value = kernel .shape [0 ] - factor
422
+ return upfirdn2d_native (
423
+ hidden_states ,
424
+ kernel .to (device = hidden_states .device ),
425
+ up = factor ,
426
+ pad = ((pad_value + 1 ) // 2 + factor - 1 , pad_value // 2 ),
427
+ )
402
428
403
429
404
- def downsample_2d (x , kernel = None , factor = 2 , gain = 1 ):
430
+ def downsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
405
431
r"""Downsample2D a batch of 2D images with the given filter.
406
432
407
433
Args:
@@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
429
455
kernel /= torch .sum (kernel )
430
456
431
457
kernel = kernel * gain
432
- p = kernel .shape [0 ] - factor
433
- return upfirdn2d_native (x , kernel .to (device = x .device ), down = factor , pad = ((p + 1 ) // 2 , p // 2 ))
458
+ pad_value = kernel .shape [0 ] - factor
459
+ return upfirdn2d_native (
460
+ hidden_states , kernel .to (device = hidden_states .device ), down = factor , pad = ((pad_value + 1 ) // 2 , pad_value // 2 )
461
+ )
434
462
435
463
436
464
def upfirdn2d_native (input , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
@@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
441
469
442
470
_ , channel , in_h , in_w = input .shape
443
471
input = input .reshape (- 1 , in_h , in_w , 1 )
472
+ # TODO: rename this variable (`input`); it shadows a Python builtin. Flagged by SonarLint rule python:S5806.
444
473
445
474
_ , in_h , in_w , minor = input .shape
446
475
kernel_h , kernel_w = kernel .shape
0 commit comments