Skip to content

Commit 1facd9f

Browse files
committed
small change
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
1 parent 0bb3f04 commit 1facd9f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __call__(
218218
timesteps = jnp.array([timesteps], dtype=jnp.int32)
219219
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
220220
timesteps = timesteps.astype(dtype=jnp.float32)
221-
timesteps = timesteps[None]
221+
timesteps = jnp.expand_dims(timesteps, 0)
222222

223223
t_emb = self.time_proj(timesteps)
224224
t_emb = self.time_embedding(t_emb)

0 commit comments

Comments
 (0)