Skip to content

Commit 3fc8ef7

Browse files
Replace dropout_prob by dropout in vae (#595)
replace `dropout_prob` by `dropout` in `vae`
1 parent 8685699 commit 3fc8ef7

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/models/vae_flax.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(self, hidden_states):
8989
class FlaxResnetBlock2D(nn.Module):
9090
in_channels: int
9191
out_channels: int = None
92-
dropout_prob: float = 0.0
92+
dropout: float = 0.0
9393
use_nin_shortcut: bool = None
9494
dtype: jnp.dtype = jnp.float32
9595

@@ -106,7 +106,7 @@ def setup(self):
106106
)
107107

108108
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
109-
self.dropout = nn.Dropout(self.dropout_prob)
109+
self.dropout_layer = nn.Dropout(self.dropout)
110110
self.conv2 = nn.Conv(
111111
out_channels,
112112
kernel_size=(3, 3),
@@ -135,7 +135,7 @@ def __call__(self, hidden_states, deterministic=True):
135135

136136
hidden_states = self.norm2(hidden_states)
137137
hidden_states = nn.swish(hidden_states)
138-
hidden_states = self.dropout(hidden_states, deterministic)
138+
hidden_states = self.dropout_layer(hidden_states, deterministic)
139139
hidden_states = self.conv2(hidden_states)
140140

141141
if self.conv_shortcut is not None:
@@ -217,7 +217,7 @@ def setup(self):
217217
res_block = FlaxResnetBlock2D(
218218
in_channels=in_channels,
219219
out_channels=self.out_channels,
220-
dropout_prob=self.dropout,
220+
dropout=self.dropout,
221221
dtype=self.dtype,
222222
)
223223
resnets.append(res_block)
@@ -251,7 +251,7 @@ def setup(self):
251251
res_block = FlaxResnetBlock2D(
252252
in_channels=in_channels,
253253
out_channels=self.out_channels,
254-
dropout_prob=self.dropout,
254+
dropout=self.dropout,
255255
dtype=self.dtype,
256256
)
257257
resnets.append(res_block)
@@ -284,7 +284,7 @@ def setup(self):
284284
FlaxResnetBlock2D(
285285
in_channels=self.in_channels,
286286
out_channels=self.in_channels,
287-
dropout_prob=self.dropout,
287+
dropout=self.dropout,
288288
dtype=self.dtype,
289289
)
290290
]
@@ -300,7 +300,7 @@ def setup(self):
300300
res_block = FlaxResnetBlock2D(
301301
in_channels=self.in_channels,
302302
out_channels=self.in_channels,
303-
dropout_prob=self.dropout,
303+
dropout=self.dropout,
304304
dtype=self.dtype,
305305
)
306306
resnets.append(res_block)

0 commit comments

Comments
 (0)