Skip to content

Commit c01ec2d

Browse files
authored
[FlaxAutoencoderKL] rename weights to align with PT (#584)
* rename weights to align with PT * DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution * fix name
1 parent 0902449 commit c01ec2d

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

src/diffusers/models/vae_flax.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
3434
Output of AutoencoderKL encoding method.
3535
3636
Args:
37-
latent_dist (`DiagonalGaussianDistribution`):
38-
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
39-
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
37+
latent_dist (`FlaxDiagonalGaussianDistribution`):
38+
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
39+
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
4040
"""
4141

42-
latent_dist: "DiagonalGaussianDistribution"
42+
latent_dist: "FlaxDiagonalGaussianDistribution"
4343

4444

45-
class Upsample2D(nn.Module):
45+
class FlaxUpsample2D(nn.Module):
4646
in_channels: int
4747
dtype: jnp.dtype = jnp.float32
4848

@@ -66,7 +66,7 @@ def __call__(self, hidden_states):
6666
return hidden_states
6767

6868

69-
class Downsample2D(nn.Module):
69+
class FlaxDownsample2D(nn.Module):
7070
in_channels: int
7171
dtype: jnp.dtype = jnp.float32
7272

@@ -86,7 +86,7 @@ def __call__(self, hidden_states):
8686
return hidden_states
8787

8888

89-
class ResnetBlock2D(nn.Module):
89+
class FlaxResnetBlock2D(nn.Module):
9090
in_channels: int
9191
out_channels: int = None
9292
dropout_prob: float = 0.0
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, deterministic=True):
144144
return hidden_states + residual
145145

146146

147-
class AttentionBlock(nn.Module):
147+
class FlaxAttentionBlock(nn.Module):
148148
channels: int
149149
num_head_channels: int = None
150150
dtype: jnp.dtype = jnp.float32
@@ -201,7 +201,7 @@ def __call__(self, hidden_states):
201201
return hidden_states
202202

203203

204-
class DownEncoderBlock2D(nn.Module):
204+
class FlaxDownEncoderBlock2D(nn.Module):
205205
in_channels: int
206206
out_channels: int
207207
dropout: float = 0.0
@@ -214,7 +214,7 @@ def setup(self):
214214
for i in range(self.num_layers):
215215
in_channels = self.in_channels if i == 0 else self.out_channels
216216

217-
res_block = ResnetBlock2D(
217+
res_block = FlaxResnetBlock2D(
218218
in_channels=in_channels,
219219
out_channels=self.out_channels,
220220
dropout_prob=self.dropout,
@@ -224,19 +224,19 @@ def setup(self):
224224
self.resnets = resnets
225225

226226
if self.add_downsample:
227-
self.downsample = Downsample2D(self.out_channels, dtype=self.dtype)
227+
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
228228

229229
def __call__(self, hidden_states, deterministic=True):
230230
for resnet in self.resnets:
231231
hidden_states = resnet(hidden_states, deterministic=deterministic)
232232

233233
if self.add_downsample:
234-
hidden_states = self.downsample(hidden_states)
234+
hidden_states = self.downsamplers_0(hidden_states)
235235

236236
return hidden_states
237237

238238

239-
class UpEncoderBlock2D(nn.Module):
239+
class FlaxUpEncoderBlock2D(nn.Module):
240240
in_channels: int
241241
out_channels: int
242242
dropout: float = 0.0
@@ -248,7 +248,7 @@ def setup(self):
248248
resnets = []
249249
for i in range(self.num_layers):
250250
in_channels = self.in_channels if i == 0 else self.out_channels
251-
res_block = ResnetBlock2D(
251+
res_block = FlaxResnetBlock2D(
252252
in_channels=in_channels,
253253
out_channels=self.out_channels,
254254
dropout_prob=self.dropout,
@@ -259,19 +259,19 @@ def setup(self):
259259
self.resnets = resnets
260260

261261
if self.add_upsample:
262-
self.upsample = Upsample2D(self.out_channels, dtype=self.dtype)
262+
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
263263

264264
def __call__(self, hidden_states, deterministic=True):
265265
for resnet in self.resnets:
266266
hidden_states = resnet(hidden_states, deterministic=deterministic)
267267

268268
if self.add_upsample:
269-
hidden_states = self.upsample(hidden_states)
269+
hidden_states = self.upsamplers_0(hidden_states)
270270

271271
return hidden_states
272272

273273

274-
class UNetMidBlock2D(nn.Module):
274+
class FlaxUNetMidBlock2D(nn.Module):
275275
in_channels: int
276276
dropout: float = 0.0
277277
num_layers: int = 1
@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
281281
def setup(self):
282282
# there is always at least one resnet
283283
resnets = [
284-
ResnetBlock2D(
284+
FlaxResnetBlock2D(
285285
in_channels=self.in_channels,
286286
out_channels=self.in_channels,
287287
dropout_prob=self.dropout,
@@ -292,12 +292,12 @@ def setup(self):
292292
attentions = []
293293

294294
for _ in range(self.num_layers):
295-
attn_block = AttentionBlock(
295+
attn_block = FlaxAttentionBlock(
296296
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
297297
)
298298
attentions.append(attn_block)
299299

300-
res_block = ResnetBlock2D(
300+
res_block = FlaxResnetBlock2D(
301301
in_channels=self.in_channels,
302302
out_channels=self.in_channels,
303303
dropout_prob=self.dropout,
@@ -317,7 +317,7 @@ def __call__(self, hidden_states, deterministic=True):
317317
return hidden_states
318318

319319

320-
class Encoder(nn.Module):
320+
class FlaxEncoder(nn.Module):
321321
in_channels: int = 3
322322
out_channels: int = 3
323323
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -347,7 +347,7 @@ def setup(self):
347347
output_channel = block_out_channels[i]
348348
is_final_block = i == len(block_out_channels) - 1
349349

350-
down_block = DownEncoderBlock2D(
350+
down_block = FlaxDownEncoderBlock2D(
351351
in_channels=input_channel,
352352
out_channels=output_channel,
353353
num_layers=self.layers_per_block,
@@ -358,7 +358,7 @@ def setup(self):
358358
self.down_blocks = down_blocks
359359

360360
# middle
361-
self.mid_block = UNetMidBlock2D(
361+
self.mid_block = FlaxUNetMidBlock2D(
362362
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
363363
)
364364

@@ -392,7 +392,7 @@ def __call__(self, sample, deterministic: bool = True):
392392
return sample
393393

394394

395-
class Decoder(nn.Module):
395+
class FlaxDecoder(nn.Module):
396396
dtype: jnp.dtype = jnp.float32
397397
in_channels: int = 3
398398
out_channels: int = 3
@@ -415,7 +415,7 @@ def setup(self):
415415
)
416416

417417
# middle
418-
self.mid_block = UNetMidBlock2D(
418+
self.mid_block = FlaxUNetMidBlock2D(
419419
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
420420
)
421421

@@ -429,7 +429,7 @@ def setup(self):
429429

430430
is_final_block = i == len(block_out_channels) - 1
431431

432-
up_block = UpEncoderBlock2D(
432+
up_block = FlaxUpEncoderBlock2D(
433433
in_channels=prev_output_channel,
434434
out_channels=output_channel,
435435
num_layers=self.layers_per_block + 1,
@@ -469,7 +469,7 @@ def __call__(self, sample, deterministic: bool = True):
469469
return sample
470470

471471

472-
class DiagonalGaussianDistribution(object):
472+
class FlaxDiagonalGaussianDistribution(object):
473473
def __init__(self, parameters, deterministic=False):
474474
# Last axis to account for channels-last
475475
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
521521
dtype: jnp.dtype = jnp.float32
522522

523523
def setup(self):
524-
self.encoder = Encoder(
524+
self.encoder = FlaxEncoder(
525525
in_channels=self.config.in_channels,
526526
out_channels=self.config.latent_channels,
527527
down_block_types=self.config.down_block_types,
@@ -532,7 +532,7 @@ def setup(self):
532532
double_z=True,
533533
dtype=self.dtype,
534534
)
535-
self.decoder = Decoder(
535+
self.decoder = FlaxDecoder(
536536
in_channels=self.config.latent_channels,
537537
out_channels=self.config.out_channels,
538538
up_block_types=self.config.up_block_types,
@@ -572,7 +572,7 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
572572

573573
hidden_states = self.encoder(sample, deterministic=deterministic)
574574
moments = self.quant_conv(hidden_states)
575-
posterior = DiagonalGaussianDistribution(moments)
575+
posterior = FlaxDiagonalGaussianDistribution(moments)
576576

577577
if not return_dict:
578578
return (posterior,)

0 commit comments

Comments
 (0)