Skip to content

Commit 7fc53b5

Browse files
authored
Fix dimensionalities in apply_rotary_emb functions' comments (#11717)
Fix dimensionality in `apply_rotary_emb` functions' comments.
1 parent 0874dd0 commit 7fc53b5

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,11 +1199,11 @@ def apply_rotary_emb(
11991199

12001200
if use_real_unbind_dim == -1:
12011201
# Used for flux, cogvideox, hunyuan-dit
1202-
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
1202+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
12031203
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
12041204
elif use_real_unbind_dim == -2:
12051205
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
1206-
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
1206+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
12071207
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
12081208
else:
12091209
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def forward(
481481

482482
def apply_rotary_emb(x, freqs):
483483
cos, sin = freqs
484-
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
484+
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
485485
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
486486
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
487487
return out

0 commit comments

Comments
 (0)