@@ -89,7 +89,7 @@ def __call__(self, hidden_states):
89
89
class FlaxResnetBlock2D (nn .Module ):
90
90
in_channels : int
91
91
out_channels : int = None
92
- dropout_prob : float = 0.0
92
+ dropout : float = 0.0
93
93
use_nin_shortcut : bool = None
94
94
dtype : jnp .dtype = jnp .float32
95
95
@@ -106,7 +106,7 @@ def setup(self):
106
106
)
107
107
108
108
self .norm2 = nn .GroupNorm (num_groups = 32 , epsilon = 1e-6 )
109
- self .dropout = nn .Dropout (self .dropout_prob )
109
+ self .dropout_layer = nn .Dropout (self .dropout )
110
110
self .conv2 = nn .Conv (
111
111
out_channels ,
112
112
kernel_size = (3 , 3 ),
@@ -135,7 +135,7 @@ def __call__(self, hidden_states, deterministic=True):
135
135
136
136
hidden_states = self .norm2 (hidden_states )
137
137
hidden_states = nn .swish (hidden_states )
138
- hidden_states = self .dropout (hidden_states , deterministic )
138
+ hidden_states = self .dropout_layer (hidden_states , deterministic )
139
139
hidden_states = self .conv2 (hidden_states )
140
140
141
141
if self .conv_shortcut is not None :
@@ -217,7 +217,7 @@ def setup(self):
217
217
res_block = FlaxResnetBlock2D (
218
218
in_channels = in_channels ,
219
219
out_channels = self .out_channels ,
220
- dropout_prob = self .dropout ,
220
+ dropout = self .dropout ,
221
221
dtype = self .dtype ,
222
222
)
223
223
resnets .append (res_block )
@@ -251,7 +251,7 @@ def setup(self):
251
251
res_block = FlaxResnetBlock2D (
252
252
in_channels = in_channels ,
253
253
out_channels = self .out_channels ,
254
- dropout_prob = self .dropout ,
254
+ dropout = self .dropout ,
255
255
dtype = self .dtype ,
256
256
)
257
257
resnets .append (res_block )
@@ -284,7 +284,7 @@ def setup(self):
284
284
FlaxResnetBlock2D (
285
285
in_channels = self .in_channels ,
286
286
out_channels = self .in_channels ,
287
- dropout_prob = self .dropout ,
287
+ dropout = self .dropout ,
288
288
dtype = self .dtype ,
289
289
)
290
290
]
@@ -300,7 +300,7 @@ def setup(self):
300
300
res_block = FlaxResnetBlock2D (
301
301
in_channels = self .in_channels ,
302
302
out_channels = self .in_channels ,
303
- dropout_prob = self .dropout ,
303
+ dropout = self .dropout ,
304
304
dtype = self .dtype ,
305
305
)
306
306
resnets .append (res_block )
0 commit comments