Compiling Qwen2DecoderLayer

There are accuracy issues when compiling `Qwen2DecoderLayer` (fp32).

`F.scaled_dot_product_attention` is broken, see below. `F.scaled_dot_product_attention` aside, there are still accuracy issues; I haven't dug too deep here yet.

tl;dr: there are no issues with the exported model, but lowering it through torch_tensorrt and running fp32 inference with the compiled model produces incorrect results. PyTorch sdpa and a manual sdpa-equivalent both produce the same error, `compiled mse: 1.04e-04, max diff: 3.26e-01`, and both lower the attention softmax to a fused myelin kernel (see below for more):
- `__myl_MulMaxSubExpSumDivMul_0x69e0487bbf9a2f96d4c2b4b256d1f286` (`F.scaled_dot_product_attention`)
- `__myl_RepAddMaxSubExpSumDivMul_0x85f8de7438077f8edb4e4c2819d66f3f` (`_sdpa`)
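For reference, my reading of those fused kernel names (just an interpretation of the op abbreviations, not confirmed against the engine) is the scaled, numerically-stable softmax chain:

```python
import torch

def stable_softmax_attn_weights(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # op chain that MulMaxSubExpSumDivMul appears to encode
    # (the gemms show up as separate xmma kernels in the profile; the trailing
    # Mul is presumably a fusion with a neighboring elementwise op)
    scores = (q * scale) @ k.transpose(-2, -1)             # Mul
    scores = scores - scores.amax(dim=-1, keepdim=True)    # Max, Sub
    weights = scores.exp()                                 # Exp
    return weights / weights.sum(dim=-1, keepdim=True)     # Sum, Div
```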
Removing these kernels from the graph (e.g. via the `_uniform_attn` variant below, which drops the softmax entirely) largely reduces the error, to `compiled mse: 3.09e-09, max diff: 2.43e-04`, but that is still high (there are probably other issues too). Any thoughts?
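One experiment that might help isolate this (a sketch, reusing `exported_program` and `inputs` from the repro below; I haven't verified which op names survive torch_tensorrt's lowering passes) is forcing the softmax/sdpa ops to run in PyTorch via `torch_executed_ops`, so the fused softmax kernel never gets built:

```python
# sketch: keep softmax/sdpa in PyTorch instead of TensorRT to isolate the
# fused softmax kernel; the op names below are guesses and may need adjusting
# to whatever actually appears in the lowered graph
compiled_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=[inputs],
    enabled_precisions={torch.float},
    torch_executed_ops={
        "torch.ops.aten._softmax.default",
        "torch.ops.aten.scaled_dot_product_attention.default",
    },
)
```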
- torch: 2.5.1
- torch_tensorrt: 2.5.0+cu124
- tensorrt: 10.3
cc @narendasan @peri044 @lanluo-nvidia 🙏
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
from torch.profiler import ProfilerActivity, profile
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

torch.manual_seed(42)
torch.cuda.manual_seed(42)


def _sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # manual, eager-equivalent sdpa: scaled scores + additive mask + softmax
    _, _, _, E = q.shape
    if scale is None:
        scale = 1.0 / math.sqrt(E)
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    if attn_mask is not None:
        attn += attn_mask
    attn = attn.softmax(dim=-1)
    return attn @ v


def _uniform_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # sdpa-shaped stand-in that replaces the softmax with uniform weights
    attn_scores = q @ k.transpose(-2, -1)
    uniform_weights = torch.ones_like(attn_scores) / attn_scores.shape[-1]
    return uniform_weights @ v


# NOTE: monkey-patching the sdpa function
# 1. _sdpa produces a different kernel, but the error is the same
# 2. _uniform_attn produces a different kernel, and there is no error
# F.scaled_dot_product_attention = _sdpa
# F.scaled_dot_product_attention = _uniform_attn
print(f"sdpa function: {F.scaled_dot_product_attention.__name__}")


class Qwen2LayerWrapper(nn.Module):
    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.layer = Qwen2DecoderLayer(config, layer_idx)

    def forward(self, inputs):
        return self.layer(
            hidden_states=inputs["hidden_states"],
            attention_mask=inputs["attention_mask"],
            position_embeddings=inputs["position_embeddings"],
        )[0]


config = Qwen2Config(
    hidden_size=896,
    intermediate_size=4864,
    num_attention_heads=14,
    num_key_value_heads=2,
    rms_norm_eps=1e-6,
    _attn_implementation="sdpa",
)
model = Qwen2LayerWrapper(config).cuda().eval()

seq_len, batch_size = 120, 1
head_dim = config.hidden_size // config.num_attention_heads
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda")
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len, device="cuda")
cos = torch.randn(batch_size, seq_len, head_dim, device="cuda") * 1000
sin = torch.randn(batch_size, seq_len, head_dim, device="cuda") * 1000
inputs = {
    "hidden_states": hidden_states,
    "attention_mask": attention_mask,
    "position_embeddings": (cos, sin),
}

# eager fp32 baseline
with torch.inference_mode():
    original_output = model(inputs)

# export and sanity-check the exported graph against eager
exported_program = torch.export.export(model, (inputs,))
exported_output = exported_program.module()(inputs)
export_mse = torch.mean((original_output - exported_output) ** 2).item()
max_diff = torch.max(torch.abs(original_output - exported_output)).item()
print(f"export mse: {export_mse:.2e}, max diff: {max_diff:.2e}")

# lower to TensorRT at fp32
compiled_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=[inputs],
    enabled_precisions={torch.float},
)

# run compiled inference under the profiler to capture the CUDA kernel names
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with torch.inference_mode():
        compiled_output = compiled_model(inputs)

for event in prof.events():
    if event.device_type == torch.profiler.DeviceType.CUDA and hasattr(event, "name"):
        print(f" {event.name}")

compiled_mse = torch.mean((original_output - compiled_output) ** 2).item()
max_diff = torch.max(torch.abs(original_output - compiled_output)).item()
print(f"\ncompiled mse: {compiled_mse:.2e}, max diff: {max_diff:.2e}")
```
Running the above, monkey-patching in each attention implementation:
```
sdpa function: scaled_dot_product_attention
export mse: 0.00e+00, max diff: 0.00e+00
 __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
 __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_ResTraResRep_0xa6fbb435a519bc9882c94d2890c02b74
 __myl_ResTraSliNegSliConRepRepMulMulAdd_0x70c8aa5401751de45fba3a553adfc607
 __myl_ResTraSliNegSliConRepRepMulMulAddResRepResTra_0xb58d3f3ce414b4ca89146687fcb73b0f
 sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize64x32x64_stage5_warpsize2x2x1_tensor16x8x8_execute_kernel_trt
 __myl_MulMaxSubExpSumDivMul_0x69e0487bbf9a2f96d4c2b4b256d1f286
 sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
 __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
 __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 1.04e-04, max diff: 3.26e-01
```
```
sdpa function: _sdpa
export mse: 0.00e+00, max diff: 0.00e+00
 __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
 __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_ResTraResRep_0xa6fbb435a519bc9882c94d2890c02b74
 __myl_ResTraSliNegSliConRepRepMulMulAddResRepResTra_0xb58d3f3ce414b4ca89146687fcb73b0f
 __myl_ResTraSliNegSliConRepRepMulMulAdd_0x70c8aa5401751de45fba3a553adfc607
 sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize64x32x64_stage5_warpsize2x2x1_tensor16x8x8_execute_kernel_trt
 __myl_RepAddMaxSubExpSumDivMul_0x85f8de7438077f8edb4e4c2819d66f3f
 sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
 __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
 __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 1.04e-04, max diff: 3.26e-01
```
```
sdpa function: _uniform_attn
export mse: 0.00e+00, max diff: 0.00e+00
 __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
 __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 __myl_TraResRep_0x8ee8320c2c0582a91dccdbb655981f23
 sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
 __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
 sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
 __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
 sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
 __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 3.09e-09, max diff: 2.43e-04
```
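Possibly relevant: all the gemms in the profiles above are `tf32f32` kernels (as far as I know, TensorRT enables TF32 for fp32 builds by default). A quick eager-side check of how much error TF32 matmul accumulation alone introduces (a sketch, reusing `model`, `inputs`, and `original_output` from the repro):

```python
# sketch: re-run the eager model with TF32 matmuls enabled and compare against
# the strict-fp32 eager baseline computed above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
with torch.inference_mode():
    tf32_output = model(inputs)
tf32_mse = torch.mean((original_output - tf32_output) ** 2).item()
tf32_max_diff = torch.max(torch.abs(original_output - tf32_output)).item()
print(f"tf32 eager mse: {tf32_mse:.2e}, max diff: {tf32_max_diff:.2e}")
```

If that lands in the same ballpark as the compiled error, TF32 accumulation would explain most of the gap; if not, the fused softmax kernels remain the prime suspect.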