diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index eba9259b8201..1471e715bc7c 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -89,7 +89,7 @@ def __call__(self, hidden_states):
 class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
-    dropout_prob: float = 0.0
+    dropout: float = 0.0
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32
 
@@ -106,7 +106,7 @@ def setup(self):
         )
 
         self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.dropout = nn.Dropout(self.dropout_prob)
+        self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -135,7 +135,7 @@ def __call__(self, hidden_states, deterministic=True):
 
         hidden_states = self.norm2(hidden_states)
         hidden_states = nn.swish(hidden_states)
-        hidden_states = self.dropout(hidden_states, deterministic)
+        hidden_states = self.dropout_layer(hidden_states, deterministic)
         hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
@@ -217,7 +217,7 @@ def setup(self):
             res_block = FlaxResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
-                dropout_prob=self.dropout,
+                dropout=self.dropout,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -251,7 +251,7 @@ def setup(self):
             res_block = FlaxResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
-                dropout_prob=self.dropout,
+                dropout=self.dropout,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -284,7 +284,7 @@ def setup(self):
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
-                dropout_prob=self.dropout,
+                dropout=self.dropout,
                 dtype=self.dtype,
             )
         ]
@@ -300,7 +300,7 @@ def setup(self):
             res_block = FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
-                dropout_prob=self.dropout,
+                dropout=self.dropout,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
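For reference, a minimal usage sketch of the renamed attributes (not part of the patch; the tensor shape, RNG keys, and import path are illustrative and assume `diffusers` is installed with Flax/JAX support). After this change, `dropout` is the float hyperparameter declared as a dataclass field on the module, while the `nn.Dropout` instance created in `setup()` lives under `dropout_layer`, so the two names no longer collide:

```python
import jax
import jax.numpy as jnp

# Import path assumed from the file being patched.
from diffusers.models.vae_flax import FlaxResnetBlock2D

# `dropout` replaces the old `dropout_prob` constructor argument.
block = FlaxResnetBlock2D(
    in_channels=64,
    out_channels=64,
    dropout=0.1,
    dtype=jnp.float32,
)

rng = jax.random.PRNGKey(0)
sample = jnp.ones((1, 32, 32, 64))  # Flax convs default to NHWC layout

# Initialize parameters; with deterministic=True no dropout RNG is needed.
variables = block.init(rng, sample, deterministic=True)

# Training-style call: pass a "dropout" RNG and set deterministic=False so
# the renamed `self.dropout_layer` actually drops activations.
out = block.apply(
    variables,
    sample,
    deterministic=False,
    rngs={"dropout": jax.random.PRNGKey(1)},
)
```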