@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
34
34
Output of AutoencoderKL encoding method.
35
35
36
36
Args:
37
- latent_dist (`DiagonalGaussianDistribution `):
38
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution `.
39
- `DiagonalGaussianDistribution ` allows for sampling latents from the distribution.
37
+ latent_dist (`FlaxDiagonalGaussianDistribution `):
38
+ Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution `.
39
+ `FlaxDiagonalGaussianDistribution ` allows for sampling latents from the distribution.
40
40
"""
41
41
42
- latent_dist : "DiagonalGaussianDistribution "
42
+ latent_dist : "FlaxDiagonalGaussianDistribution "
43
43
44
44
45
- class Upsample2D (nn .Module ):
45
+ class FlaxUpsample2D (nn .Module ):
46
46
in_channels : int
47
47
dtype : jnp .dtype = jnp .float32
48
48
@@ -66,7 +66,7 @@ def __call__(self, hidden_states):
66
66
return hidden_states
67
67
68
68
69
- class Downsample2D (nn .Module ):
69
+ class FlaxDownsample2D (nn .Module ):
70
70
in_channels : int
71
71
dtype : jnp .dtype = jnp .float32
72
72
@@ -86,7 +86,7 @@ def __call__(self, hidden_states):
86
86
return hidden_states
87
87
88
88
89
- class ResnetBlock2D (nn .Module ):
89
+ class FlaxResnetBlock2D (nn .Module ):
90
90
in_channels : int
91
91
out_channels : int = None
92
92
dropout_prob : float = 0.0
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, deterministic=True):
144
144
return hidden_states + residual
145
145
146
146
147
- class AttentionBlock (nn .Module ):
147
+ class FlaxAttentionBlock (nn .Module ):
148
148
channels : int
149
149
num_head_channels : int = None
150
150
dtype : jnp .dtype = jnp .float32
@@ -201,7 +201,7 @@ def __call__(self, hidden_states):
201
201
return hidden_states
202
202
203
203
204
- class DownEncoderBlock2D (nn .Module ):
204
+ class FlaxDownEncoderBlock2D (nn .Module ):
205
205
in_channels : int
206
206
out_channels : int
207
207
dropout : float = 0.0
@@ -214,7 +214,7 @@ def setup(self):
214
214
for i in range (self .num_layers ):
215
215
in_channels = self .in_channels if i == 0 else self .out_channels
216
216
217
- res_block = ResnetBlock2D (
217
+ res_block = FlaxResnetBlock2D (
218
218
in_channels = in_channels ,
219
219
out_channels = self .out_channels ,
220
220
dropout_prob = self .dropout ,
@@ -224,19 +224,19 @@ def setup(self):
224
224
self .resnets = resnets
225
225
226
226
if self .add_downsample :
227
- self .downsample = Downsample2D (self .out_channels , dtype = self .dtype )
227
+ self .downsamplers_0 = FlaxDownsample2D (self .out_channels , dtype = self .dtype )
228
228
229
229
def __call__ (self , hidden_states , deterministic = True ):
230
230
for resnet in self .resnets :
231
231
hidden_states = resnet (hidden_states , deterministic = deterministic )
232
232
233
233
if self .add_downsample :
234
- hidden_states = self .downsample (hidden_states )
234
+ hidden_states = self .downsamplers_0 (hidden_states )
235
235
236
236
return hidden_states
237
237
238
238
239
- class UpEncoderBlock2D (nn .Module ):
239
+ class FlaxUpEncoderBlock2D (nn .Module ):
240
240
in_channels : int
241
241
out_channels : int
242
242
dropout : float = 0.0
@@ -248,7 +248,7 @@ def setup(self):
248
248
resnets = []
249
249
for i in range (self .num_layers ):
250
250
in_channels = self .in_channels if i == 0 else self .out_channels
251
- res_block = ResnetBlock2D (
251
+ res_block = FlaxResnetBlock2D (
252
252
in_channels = in_channels ,
253
253
out_channels = self .out_channels ,
254
254
dropout_prob = self .dropout ,
@@ -259,19 +259,19 @@ def setup(self):
259
259
self .resnets = resnets
260
260
261
261
if self .add_upsample :
262
- self .upsample = Upsample2D (self .out_channels , dtype = self .dtype )
262
+ self .upsamplers_0 = FlaxUpsample2D (self .out_channels , dtype = self .dtype )
263
263
264
264
def __call__ (self , hidden_states , deterministic = True ):
265
265
for resnet in self .resnets :
266
266
hidden_states = resnet (hidden_states , deterministic = deterministic )
267
267
268
268
if self .add_upsample :
269
- hidden_states = self .upsample (hidden_states )
269
+ hidden_states = self .upsamplers_0 (hidden_states )
270
270
271
271
return hidden_states
272
272
273
273
274
- class UNetMidBlock2D (nn .Module ):
274
+ class FlaxUNetMidBlock2D (nn .Module ):
275
275
in_channels : int
276
276
dropout : float = 0.0
277
277
num_layers : int = 1
@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
281
281
def setup (self ):
282
282
# there is always at least one resnet
283
283
resnets = [
284
- ResnetBlock2D (
284
+ FlaxResnetBlock2D (
285
285
in_channels = self .in_channels ,
286
286
out_channels = self .in_channels ,
287
287
dropout_prob = self .dropout ,
@@ -292,12 +292,12 @@ def setup(self):
292
292
attentions = []
293
293
294
294
for _ in range (self .num_layers ):
295
- attn_block = AttentionBlock (
295
+ attn_block = FlaxAttentionBlock (
296
296
channels = self .in_channels , num_head_channels = self .attn_num_head_channels , dtype = self .dtype
297
297
)
298
298
attentions .append (attn_block )
299
299
300
- res_block = ResnetBlock2D (
300
+ res_block = FlaxResnetBlock2D (
301
301
in_channels = self .in_channels ,
302
302
out_channels = self .in_channels ,
303
303
dropout_prob = self .dropout ,
@@ -317,7 +317,7 @@ def __call__(self, hidden_states, deterministic=True):
317
317
return hidden_states
318
318
319
319
320
- class Encoder (nn .Module ):
320
+ class FlaxEncoder (nn .Module ):
321
321
in_channels : int = 3
322
322
out_channels : int = 3
323
323
down_block_types : Tuple [str ] = ("DownEncoderBlock2D" ,)
@@ -347,7 +347,7 @@ def setup(self):
347
347
output_channel = block_out_channels [i ]
348
348
is_final_block = i == len (block_out_channels ) - 1
349
349
350
- down_block = DownEncoderBlock2D (
350
+ down_block = FlaxDownEncoderBlock2D (
351
351
in_channels = input_channel ,
352
352
out_channels = output_channel ,
353
353
num_layers = self .layers_per_block ,
@@ -358,7 +358,7 @@ def setup(self):
358
358
self .down_blocks = down_blocks
359
359
360
360
# middle
361
- self .mid_block = UNetMidBlock2D (
361
+ self .mid_block = FlaxUNetMidBlock2D (
362
362
in_channels = block_out_channels [- 1 ], attn_num_head_channels = None , dtype = self .dtype
363
363
)
364
364
@@ -392,7 +392,7 @@ def __call__(self, sample, deterministic: bool = True):
392
392
return sample
393
393
394
394
395
- class Decoder (nn .Module ):
395
+ class FlaxDecoder (nn .Module ):
396
396
dtype : jnp .dtype = jnp .float32
397
397
in_channels : int = 3
398
398
out_channels : int = 3
@@ -415,7 +415,7 @@ def setup(self):
415
415
)
416
416
417
417
# middle
418
- self .mid_block = UNetMidBlock2D (
418
+ self .mid_block = FlaxUNetMidBlock2D (
419
419
in_channels = block_out_channels [- 1 ], attn_num_head_channels = None , dtype = self .dtype
420
420
)
421
421
@@ -429,7 +429,7 @@ def setup(self):
429
429
430
430
is_final_block = i == len (block_out_channels ) - 1
431
431
432
- up_block = UpEncoderBlock2D (
432
+ up_block = FlaxUpEncoderBlock2D (
433
433
in_channels = prev_output_channel ,
434
434
out_channels = output_channel ,
435
435
num_layers = self .layers_per_block + 1 ,
@@ -469,7 +469,7 @@ def __call__(self, sample, deterministic: bool = True):
469
469
return sample
470
470
471
471
472
- class DiagonalGaussianDistribution (object ):
472
+ class FlaxDiagonalGaussianDistribution (object ):
473
473
def __init__ (self , parameters , deterministic = False ):
474
474
# Last axis to account for channels-last
475
475
self .mean , self .logvar = jnp .split (parameters , 2 , axis = - 1 )
@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
521
521
dtype : jnp .dtype = jnp .float32
522
522
523
523
def setup (self ):
524
- self .encoder = Encoder (
524
+ self .encoder = FlaxEncoder (
525
525
in_channels = self .config .in_channels ,
526
526
out_channels = self .config .latent_channels ,
527
527
down_block_types = self .config .down_block_types ,
@@ -532,7 +532,7 @@ def setup(self):
532
532
double_z = True ,
533
533
dtype = self .dtype ,
534
534
)
535
- self .decoder = Decoder (
535
+ self .decoder = FlaxDecoder (
536
536
in_channels = self .config .latent_channels ,
537
537
out_channels = self .config .out_channels ,
538
538
up_block_types = self .config .up_block_types ,
@@ -572,7 +572,7 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
572
572
573
573
hidden_states = self .encoder (sample , deterministic = deterministic )
574
574
moments = self .quant_conv (hidden_states )
575
- posterior = DiagonalGaussianDistribution (moments )
575
+ posterior = FlaxDiagonalGaussianDistribution (moments )
576
576
577
577
if not return_dict :
578
578
return (posterior ,)
0 commit comments