@@ -559,7 +559,7 @@ def setup(self):
559
559
560
560
def init_weights (self , rng : jax .random .PRNGKey ) -> FrozenDict :
561
561
# init input tensors
562
- sample_shape = (1 , self .in_channels , self .sample_size , self .sample_size )
562
+ sample_shape = (1 , self .sample_size , self .sample_size , self .in_channels )
563
563
sample = jnp .zeros (sample_shape , dtype = jnp .float32 )
564
564
565
565
params_rng , dropout_rng , gaussian_rng = jax .random .split (rng , 3 )
@@ -568,8 +568,6 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
568
568
return self .init (rngs , sample )["params" ]
569
569
570
570
def encode (self , sample , deterministic : bool = True , return_dict : bool = True ):
571
- sample = jnp .transpose (sample , (0 , 2 , 3 , 1 ))
572
-
573
571
hidden_states = self .encoder (sample , deterministic = deterministic )
574
572
moments = self .quant_conv (hidden_states )
575
573
posterior = DiagonalGaussianDistribution (moments )
@@ -586,8 +584,6 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
586
584
hidden_states = self .post_quant_conv (latents )
587
585
hidden_states = self .decoder (hidden_states , deterministic = deterministic )
588
586
589
- hidden_states = jnp .transpose (hidden_states , (0 , 3 , 1 , 2 ))
590
-
591
587
if not return_dict :
592
588
return (hidden_states ,)
593
589
0 commit comments