Skip to content

Commit 1f196a0

Browse files
Changed variable name from "h" to "hidden_states" (#285)
* Changed variable name from "h" to "hidden_states" Per issue #198 , changed variable name from "h" to "hidden_states" in the forward function only. I am happy to change any other variable names, please advise recommended new names. * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 034673b commit 1f196a0

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/diffusers/models/resnet.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -328,39 +328,39 @@ def __init__(
328328
if self.use_nin_shortcut:
329329
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
330330

331-
def forward(self, x, temb, hey=False):
332-
h = x
331+
def forward(self, x, temb):
332+
hidden_states = x
333333

334334
# make sure hidden states is in float32
335335
# when running in half-precision
336-
h = self.norm1(h.float()).type(h.dtype)
337-
h = self.nonlinearity(h)
336+
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
337+
hidden_states = self.nonlinearity(hidden_states)
338338

339339
if self.upsample is not None:
340340
x = self.upsample(x)
341-
h = self.upsample(h)
341+
hidden_states = self.upsample(hidden_states)
342342
elif self.downsample is not None:
343343
x = self.downsample(x)
344-
h = self.downsample(h)
344+
hidden_states = self.downsample(hidden_states)
345345

346-
h = self.conv1(h)
346+
hidden_states = self.conv1(hidden_states)
347347

348348
if temb is not None:
349349
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
350-
h = h + temb
350+
hidden_states = hidden_states + temb
351351

352352
# make sure hidden states is in float32
353353
# when running in half-precision
354-
h = self.norm2(h.float()).type(h.dtype)
355-
h = self.nonlinearity(h)
354+
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
355+
hidden_states = self.nonlinearity(hidden_states)
356356

357-
h = self.dropout(h)
358-
h = self.conv2(h)
357+
hidden_states = self.dropout(hidden_states)
358+
hidden_states = self.conv2(hidden_states)
359359

360360
if self.conv_shortcut is not None:
361361
x = self.conv_shortcut(x)
362362

363-
out = (x + h) / self.output_scale_factor
363+
out = (x + hidden_states) / self.output_scale_factor
364364

365365
return out
366366

0 commit comments

Comments
 (0)