We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0bb3f04 commit 1facd9fCopy full SHA for 1facd9f
src/diffusers/models/unet_2d_condition_flax.py
@@ -218,7 +218,7 @@ def __call__(
218
timesteps = jnp.array([timesteps], dtype=jnp.int32)
219
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
220
timesteps = timesteps.astype(dtype=jnp.float32)
221
- timesteps = timesteps[None]
+ timesteps = jnp.expand_dims(timesteps, 0)
222
223
t_emb = self.time_proj(timesteps)
224
t_emb = self.time_embedding(t_emb)
0 commit comments