Optimize RoPEAttention implementation for onnx export #10

Merged
merged 1 commit on May 1, 2025

sam2/modeling/position_encoding.py (85 changes: 35 additions & 50 deletions)

@@ -163,62 +163,47 @@ def forward_with_coords(
 # 2. https://github.com/naver-ai/rope-vit
 # 3. https://github.com/lucidrains/rotary-embedding-torch


-def init_t_xy(end_x: int, end_y: int):
-    t = torch.arange(end_x * end_y, dtype=torch.float32)
+def init_t_xy(end_x: int, end_y: int, device=None):
+    t = torch.arange(end_x * end_y, dtype=torch.float32, device=device)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
     return t_x, t_y


-def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
-    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
-    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
-
+def compute_axial_rope_cos_sin(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+    # dim: must be divisible by 4
+    assert dim % 2 == 0
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
     t_x, t_y = init_t_xy(end_x, end_y)
-    freqs_x = torch.outer(t_x, freqs_x)
-    freqs_y = torch.outer(t_y, freqs_y)
-    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
-    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
-    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    freqs_x = torch.outer(t_x, freqs)  # [end_x*end_y, dim//2]
+    freqs_y = torch.outer(t_y, freqs)  # [end_x*end_y, dim//2]
+    # concatenate the x and y frequencies
+    freqs = torch.cat([freqs_x, freqs_y], dim=-1)  # [seq_len, dim]
+    cos = freqs.cos()
+    sin = freqs.sin()
+    return cos, sin  # [seq_len, dim]
+
+def reshape_for_broadcast(param, x):
+    # param: [seq_len, dim], x: [..., seq_len, dim]
+    # reshape param to [1, ..., seq_len, dim]
     ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
-    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-
-def apply_rotary_enc(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-    repeat_freqs_k: bool = False,
-):
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = (
-        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        if xk.shape[-2] != 0
-        else None
-    )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    if xk_ is None:
-        # no keys to rotate, due to dropout
-        return xq_out.type_as(xq).to(xq.device), xk
-    # repeat freqs along seq_len dim to match k seq_len
-    if repeat_freqs_k:
-        r = xk_.shape[-2] // xq_.shape[-2]
-        if freqs_cis.is_cuda:
-            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
-        else:
-            # torch.repeat on complex numbers may not be supported on non-CUDA devices
-            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
-            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+    shape = [1] * (ndim - 2) + list(param.shape)
+    return param.view(*shape)
+
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+    # x: [B, n_head, seq_len, head_dim]; head_dim must be even
+    # cos/sin: [seq_len, head_dim]
+    # broadcast cos/sin to x's shape
+    while cos.ndim < x.ndim:
+        cos = cos.unsqueeze(0)
+        sin = sin.unsqueeze(0)
+    x1, x2 = x[..., ::2], x[..., 1::2]
+    cos, sin = cos[..., ::2], sin[..., ::2]
+    x_out_even = x1 * cos - x2 * sin
+    x_out_odd = x1 * sin + x2 * cos
+    x_out = torch.stack([x_out_even, x_out_odd], dim=-1)
+    x_out = x_out.flatten(-2)  # restore the original last dimension
+    return x_out


 # Matrix version of rotary enc
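The replacement keeps the rotary tables as plain real-valued cos/sin tensors, so this path no longer goes through torch.polar / torch.view_as_complex, which the ONNX exporter does not support. A minimal usage sketch of the two new helpers (tensor sizes below are illustrative, not taken from the PR):

import torch
from sam2.modeling.position_encoding import compute_axial_rope_cos_sin, apply_rotary_emb

B, n_head, head_dim = 1, 4, 64       # head_dim must be divisible by 4
end_x = end_y = 8                    # 8x8 feature grid -> seq_len = 64
q = torch.randn(B, n_head, end_x * end_y, head_dim)

cos, sin = compute_axial_rope_cos_sin(dim=head_dim, end_x=end_x, end_y=end_y)  # each [64, 64]
q_rot = apply_rotary_emb(q, cos, sin)  # same shape as q: [1, 4, 64, 64]

For matching grid sizes, the result should agree with the original complex-valued compute_axial_cis / apply_rotary_enc path up to floating-point tolerance.
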
sam2/modeling/sam/transformer.py (106 changes: 37 additions & 69 deletions)

@@ -14,8 +14,7 @@
 import torch.nn.functional as F
 from torch import nn, Tensor

-from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
-from sam2.modeling.position_encoding import apply_rotary_matenc, get_rotation_matrices
+from sam2.modeling.position_encoding import compute_axial_rope_cos_sin, apply_rotary_emb
 from sam2.modeling.sam2_utils import MLP
 from sam2.utils.misc import get_sdpa_settings

@@ -291,95 +290,64 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:


 class RoPEAttention(Attention):
     """Attention with rotary position encoding."""

     def __init__(
         self,
         *args,
         rope_theta=10000.0,
         # whether to repeat q rope to match k length
         # this is needed for cross-attention to memories
         rope_k_repeat=False,
-        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        feat_sizes=(64, 64),
         **kwargs,
     ):
         super().__init__(*args, **kwargs)

-        self.compute_cis = partial(
-            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
-        )
-        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
-        self.freqs_cis = freqs_cis
+        self.rope_theta = rope_theta
+        self.feat_sizes = feat_sizes
         self.rope_k_repeat = rope_k_repeat
+        self._cached_shape = None
+        self._cached_cos_sin = None

+    def get_cos_sin(self, seq_len, device, dtype):
+        # cache the cos/sin tables for the last (seq_len, device, dtype) combination
+        if self._cached_shape == (seq_len, device, dtype):
+            return self._cached_cos_sin
+        w = h = int(math.sqrt(seq_len))
+        cos, sin = compute_axial_rope_cos_sin(
+            dim=self.internal_dim // self.num_heads,
+            end_x=w, end_y=h, theta=self.rope_theta
+        )
+        cos, sin = cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
+        self._cached_shape = (seq_len, device, dtype)
+        self._cached_cos_sin = (cos, sin)
+        return cos, sin

-        if USE_MAT_ROTARY_ENC:
-            rotmats = get_rotation_matrices(dim=self.internal_dim // self.num_heads, end_x=feat_sizes[0], end_y=feat_sizes[1], theta=rope_theta)
-            self.rotmats = rotmats
-            self.rope_theta = rope_theta

-    def forward(
-        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
-    ) -> Tensor:
-        # Input projections
+    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 64) -> Tensor:
         q = self.q_proj(q)
         k = self.k_proj(k)
         v = self.v_proj(v)

-        # Separate into heads
         q = self._separate_heads(q, self.num_heads)
         k = self._separate_heads(k, self.num_heads)
         v = self._separate_heads(v, self.num_heads)

-        # Apply rotary position encoding
-        w = h = math.sqrt(q.shape[-2])
-
-        self.freqs_cis = self.freqs_cis.to(q.device)
-        if self.freqs_cis.shape[0] != q.shape[-2]:
-            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
-
-        if USE_MAT_ROTARY_ENC:
-            self.rotmats = self.rotmats.to(q.device)
-            if self.rotmats.shape[0] != q.shape[-2]:
-                self.rotmats = get_rotation_matrices(dim=self.internal_dim // self.num_heads, end_x=w, end_y=h, theta=self.rope_theta)
-
-        if q.shape[-2] != k.shape[-2]:
+        seq_len = q.shape[-2]
+        cos, sin = self.get_cos_sin(seq_len, q.device, q.dtype)
+        # q: [B, n_head, seq_len, head_dim]
+        q = apply_rotary_emb(q, cos, sin)
+        if k.shape[-2] != q.shape[-2]:
             assert self.rope_k_repeat
-
+            # repeat cos/sin along the sequence dimension to match k's length
+            repeat = k.shape[-2] // q.shape[-2]
+            cos_k = cos.repeat(repeat, 1)
+            sin_k = sin.repeat(repeat, 1)
+        else:
+            cos_k, sin_k = cos, sin
         num_k_rope = k.size(-2) - num_k_exclude_rope
-        if USE_MAT_ROTARY_ENC:
-            q, k[:, :, :num_k_rope] = apply_rotary_matenc(
-                q,
-                k[:, :, :num_k_rope],
-                rotmats=self.rotmats,
-                repeat_freqs_k=self.rope_k_repeat,
-            )
+        if num_k_exclude_rope > 0:
+            k_rope = apply_rotary_emb(k[:, :, :num_k_rope], cos_k[:num_k_rope], sin_k[:num_k_rope])
+            k = torch.cat([k_rope, k[:, :, num_k_rope:]], dim=-2)
         else:
-            q, k[:, :, :num_k_rope] = apply_rotary_enc(
-                q,
-                k[:, :, :num_k_rope],
-                freqs_cis=self.freqs_cis,
-                repeat_freqs_k=self.rope_k_repeat,
-            )
+            k = apply_rotary_emb(k, cos_k, sin_k)

         dropout_p = self.dropout_p if self.training else 0.0
-        # Attention
-        #try:
-        #    with sdp_kernel_context(dropout_p):
-        #        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        #except Exception as e:
-        if True:
-            # Fall back to all kernels if the Flash attention kernel fails
-            #warnings.warn(
-            #    f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-            #    f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-            #    category=UserWarning,
-            #    stacklevel=2,
-            #)
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
         out = self.out_proj(out)

         return out
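
With the complex-valued ops and the in-place key update removed, the forward pass is much friendlier to tracing. A hypothetical export sketch (the wrapper module, constructor arguments, opset, and file name are illustrative, not part of this PR; embedding_dim/num_heads are the upstream Attention constructor arguments):

import torch
from sam2.modeling.sam.transformer import RoPEAttention

class RoPEAttentionWrapper(torch.nn.Module):
    # Illustrative wrapper so the traced graph takes only tensor inputs.
    def __init__(self):
        super().__init__()
        self.attn = RoPEAttention(embedding_dim=256, num_heads=4)

    def forward(self, q, k, v):
        # rotate all keys in this sketch
        return self.attn(q, k, v, num_k_exclude_rope=0)

model = RoPEAttentionWrapper().eval()
seq_len = 64 * 64  # matches the default feat_sizes=(64, 64)
q = k = v = torch.randn(1, seq_len, 256)
torch.onnx.export(model, (q, k, v), "rope_attention.onnx", opset_version=17)

Exporting F.scaled_dot_product_attention generally needs opset 14 or newer; 17 is just a common choice here.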