Skip to content

Commit 4cad1ae

Browse files
committed
make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
1 parent 9be80f4 commit 4cad1ae

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/diffusers/models/vae_flax.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def setup(self):
559559

560560
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
561561
# init input tensors
562-
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
562+
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
563563
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
564564

565565
params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
@@ -568,8 +568,6 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
568568
return self.init(rngs, sample)["params"]
569569

570570
def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
571-
sample = jnp.transpose(sample, (0, 2, 3, 1))
572-
573571
hidden_states = self.encoder(sample, deterministic=deterministic)
574572
moments = self.quant_conv(hidden_states)
575573
posterior = DiagonalGaussianDistribution(moments)
@@ -586,8 +584,6 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
586584
hidden_states = self.post_quant_conv(latents)
587585
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
588586

589-
hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
590-
591587
if not return_dict:
592588
return (hidden_states,)
593589

0 commit comments

Comments
 (0)