Skip to content

Commit 7265dd8

Browse files
i-am-epicNikhil A Vpatrickvonplatenpatil-suraj
authored
renamed x to meaningful variable in resnet.py (#677)
* renamed single letter variables * renamed x to meaningful variable in resnet.py Hello @patil-suraj can you verify it Thanks * Reformatted using black * renamed x to meaningful variable in resnet.py Hello @patil-suraj can you verify it Thanks * reformatted the files * modified unboundlocalerror in line 374 * removed referenced before error * renamed single variable x -> hidden_state, p-> pad_value Co-authored-by: Nikhil A V <nikhilav@Nikhils-MacBook-Pro.local> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
1 parent 14b9754 commit 7265dd8

File tree

1 file changed

+66
-37
lines changed

1 file changed

+66
-37
lines changed

src/diffusers/models/resnet.py

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
112112
self.fir_kernel = fir_kernel
113113
self.out_channels = out_channels
114114

115-
def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
115+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
116116
"""Fused `upsample_2d()` followed by `Conv2d()`.
117117
118118
Args:
@@ -151,34 +151,46 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
151151
convW = weight.shape[3]
152152
inC = weight.shape[1]
153153

154-
p = (kernel.shape[0] - factor) - (convW - 1)
154+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
155155

156156
stride = (factor, factor)
157157
# Determine data dimensions.
158-
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
158+
output_shape = (
159+
(hidden_states.shape[2] - 1) * factor + convH,
160+
(hidden_states.shape[3] - 1) * factor + convW,
161+
)
159162
output_padding = (
160-
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
161-
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
163+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
164+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
162165
)
163166
assert output_padding[0] >= 0 and output_padding[1] >= 0
164167
inC = weight.shape[1]
165-
num_groups = x.shape[1] // inC
168+
num_groups = hidden_states.shape[1] // inC
166169

167170
# Transpose weights.
168171
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
169172
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
170173
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
171174

172-
x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
175+
inverse_conv = F.conv_transpose2d(
176+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
177+
)
173178

174-
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
179+
output = upfirdn2d_native(
180+
inverse_conv,
181+
torch.tensor(kernel, device=inverse_conv.device),
182+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
183+
)
175184
else:
176-
p = kernel.shape[0] - factor
177-
x = upfirdn2d_native(
178-
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
185+
pad_value = kernel.shape[0] - factor
186+
output = upfirdn2d_native(
187+
hidden_states,
188+
torch.tensor(kernel, device=hidden_states.device),
189+
up=factor,
190+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
179191
)
180192

181-
return x
193+
return output
182194

183195
def forward(self, hidden_states):
184196
if self.use_conv:
@@ -200,7 +212,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
200212
self.use_conv = use_conv
201213
self.out_channels = out_channels
202214

203-
def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
215+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
204216
"""Fused `Conv2d()` followed by `downsample_2d()`.
205217
206218
Args:
@@ -232,20 +244,29 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
232244

233245
if self.use_conv:
234246
_, _, convH, convW = weight.shape
235-
p = (kernel.shape[0] - factor) + (convW - 1)
236-
s = [factor, factor]
237-
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
238-
x = F.conv2d(x, weight, stride=s, padding=0)
247+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
248+
stride_value = [factor, factor]
249+
upfirdn_input = upfirdn2d_native(
250+
hidden_states,
251+
torch.tensor(kernel, device=hidden_states.device),
252+
pad=((pad_value + 1) // 2, pad_value // 2),
253+
)
254+
hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
239255
else:
240-
p = kernel.shape[0] - factor
241-
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
256+
pad_value = kernel.shape[0] - factor
257+
hidden_states = upfirdn2d_native(
258+
hidden_states,
259+
torch.tensor(kernel, device=hidden_states.device),
260+
down=factor,
261+
pad=((pad_value + 1) // 2, pad_value // 2),
262+
)
242263

243-
return x
264+
return hidden_states
244265

245266
def forward(self, hidden_states):
246267
if self.use_conv:
247-
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
248-
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
268+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
269+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
249270
else:
250271
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
251272

@@ -332,17 +353,17 @@ def __init__(
332353
if self.use_in_shortcut:
333354
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
334355

335-
def forward(self, x, temb):
336-
hidden_states = x
356+
def forward(self, input_tensor, temb):
357+
hidden_states = input_tensor
337358

338359
hidden_states = self.norm1(hidden_states)
339360
hidden_states = self.nonlinearity(hidden_states)
340361

341362
if self.upsample is not None:
342-
x = self.upsample(x)
363+
input_tensor = self.upsample(input_tensor)
343364
hidden_states = self.upsample(hidden_states)
344365
elif self.downsample is not None:
345-
x = self.downsample(x)
366+
input_tensor = self.downsample(input_tensor)
346367
hidden_states = self.downsample(hidden_states)
347368

348369
hidden_states = self.conv1(hidden_states)
@@ -358,19 +379,19 @@ def forward(self, x, temb):
358379
hidden_states = self.conv2(hidden_states)
359380

360381
if self.conv_shortcut is not None:
361-
x = self.conv_shortcut(x)
382+
input_tensor = self.conv_shortcut(input_tensor)
362383

363-
out = (x + hidden_states) / self.output_scale_factor
384+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
364385

365-
return out
386+
return output_tensor
366387

367388

368389
class Mish(torch.nn.Module):
369-
def forward(self, x):
370-
return x * torch.tanh(torch.nn.functional.softplus(x))
390+
def forward(self, hidden_states):
391+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
371392

372393

373-
def upsample_2d(x, kernel=None, factor=2, gain=1):
394+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
374395
r"""Upsample2D a batch of 2D images with the given filter.
375396
376397
Args:
@@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
397418
kernel /= torch.sum(kernel)
398419

399420
kernel = kernel * (gain * (factor**2))
400-
p = kernel.shape[0] - factor
401-
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
421+
pad_value = kernel.shape[0] - factor
422+
return upfirdn2d_native(
423+
hidden_states,
424+
kernel.to(device=hidden_states.device),
425+
up=factor,
426+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
427+
)
402428

403429

404-
def downsample_2d(x, kernel=None, factor=2, gain=1):
430+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
405431
r"""Downsample2D a batch of 2D images with the given filter.
406432
407433
Args:
@@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
429455
kernel /= torch.sum(kernel)
430456

431457
kernel = kernel * gain
432-
p = kernel.shape[0] - factor
433-
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
458+
pad_value = kernel.shape[0] - factor
459+
return upfirdn2d_native(
460+
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
461+
)
434462

435463

436464
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
441469

442470
_, channel, in_h, in_w = input.shape
443471
input = input.reshape(-1, in_h, in_w, 1)
472+
# Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)
444473

445474
_, in_h, in_w, minor = input.shape
446475
kernel_h, kernel_w = kernel.shape

0 commit comments

Comments
 (0)