🐛 [Bug] Large Accuracy Issue #3626


Description

@patrick-botco

Compiling Qwen2DecoderLayer

There are accuracy issues when compiling Qwen2DecoderLayer (fp32):

  1. F.scaled_dot_product_attention is broken; see below.
  2. F.scaled_dot_product_attention aside, there are still accuracy issues; I haven't dug too deep here yet.

tl;dr: the exported model has no issues, but lowering it through torch_tensorrt and running fp32 inference with the compiled model produces incorrect results. PyTorch's sdpa and an sdpa-like variant both produce the same error, compiled mse: 1.04e-04, max diff: 3.26e-01 (see below for more):

  • __myl_MulMaxSubExpSumDivMul_0x69e0487bbf9a2f96d4c2b4b256d1f286 (F.scaled_dot_product_attention)
  • __myl_RepAddMaxSubExpSumDivMul_0x85f8de7438077f8edb4e4c2819d66f3f (_sdpa)

Removing these kernels (e.g. by swapping in _uniform_attn) reduces the error by several orders of magnitude: compiled mse: 3.09e-09, max diff: 2.43e-04,

but that is still higher than expected (there are probably other issues too). Any thoughts?
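
For reference, the fused kernel names appear to spell out the op chain of a scaled, numerically stable softmax (Mul, Max, Sub, Exp, Sum, Div, Mul), with the _sdpa variant adding Rep/Add for the broadcast mask. A minimal sketch of that decomposition, purely my reading of the kernel name (not verified against the engine):

import torch

def softmax_chain(scores, scale):
    s = scores * scale                        # Mul: apply the sdpa scale
    m = s.amax(dim=-1, keepdim=True)          # Max: rowwise max for numerical stability
    e = (s - m).exp()                         # Sub, Exp
    return e / e.sum(dim=-1, keepdim=True)    # Sum, Div (the trailing Mul in the name is unclear to me)

The gemm kernel that follows it in the profile would then presumably be the attn @ v matmul.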

torch: 2.5.1
torch_tensorrt: 2.5.0+cu124
tensorrt: 10.3

cc @narendasan @peri044 @lanluo-nvidia 🙏

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
from torch.profiler import ProfilerActivity, profile
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

torch.manual_seed(42)
torch.cuda.manual_seed(42)


# pure-PyTorch reference implementation of sdpa (ignores dropout_p / is_causal)
def _sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    _, _, _, E = q.shape
    if scale is None:
        scale = 1.0 / math.sqrt(E)
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    if attn_mask is not None:
        attn += attn_mask
    attn = attn.softmax(dim=-1)
    return attn @ v
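
# sanity check (not part of the repro): in eager mode, _sdpa should closely match
# F.scaled_dot_product_attention when no dropout or causal masking is involved, e.g.
#   x = torch.randn(1, 2, 8, 64, device="cuda")
#   assert torch.allclose(F.scaled_dot_product_attention(x, x, x), _sdpa(x, x, x), atol=1e-5)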


# uniform attention weights: q @ k^T is computed only for its shape, so the softmax path is bypassed
def _uniform_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    attn_scores = q @ k.transpose(-2, -1)
    uniform_weights = torch.ones_like(attn_scores) / attn_scores.shape[-1]
    return uniform_weights @ v


# NOTE: monkey-patching the sdpa function
# 1. _sdpa produces a different kernel, but the error is the same
# 2. _uniform_attn produces a different kernel, and there is no error
# F.scaled_dot_product_attention = _uniform_attn

print(f"sdpa function: {F.scaled_dot_product_attention.__name__}")


class Qwen2LayerWrapper(nn.Module):
    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.layer = Qwen2DecoderLayer(config, layer_idx)

    def forward(self, inputs):
        return self.layer(
            hidden_states=inputs["hidden_states"],
            attention_mask=inputs["attention_mask"],
            position_embeddings=inputs["position_embeddings"],
        )[0]


config = Qwen2Config(
    hidden_size=896,
    intermediate_size=4864,
    num_attention_heads=14,
    num_key_value_heads=2,
    rms_norm_eps=1e-6,
    _attn_implementation="sdpa",
)

model = Qwen2LayerWrapper(config).cuda().eval()

seq_len, batch_size = 120, 1
head_dim = config.hidden_size // config.num_attention_heads

hidden_states = torch.randn(batch_size, seq_len, 896, device="cuda")
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len, device="cuda")

# random rotary embeddings; note they are scaled far outside the [-1, 1] range of real cos/sin
cos = torch.randn(batch_size, seq_len, head_dim, device="cuda") * 1000
sin = torch.randn(batch_size, seq_len, head_dim, device="cuda") * 1000

inputs = {
    "hidden_states": hidden_states,
    "attention_mask": attention_mask,
    "position_embeddings": (cos, sin),
}

with torch.inference_mode():
    original_output = model(inputs)

exported_program = torch.export.export(model, (inputs,))
exported_output = exported_program.module()(inputs)

export_mse = torch.mean((original_output - exported_output) ** 2).item()
max_diff = torch.max(torch.abs(original_output - exported_output)).item()
print(f"export mse: {export_mse:.2e}, max diff: {max_diff:.2e}")

compiled_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=[inputs],
    enabled_precisions={torch.float},
)
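
# (aside) torch_tensorrt.dynamo.compile also accepts a debug flag; something like
#   torch_tensorrt.dynamo.compile(exported_program, inputs=[inputs],
#                                 enabled_precisions={torch.float}, debug=True)
# dumps the lowered graph and partitions, which may help localize the bad fusion.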

with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with torch.inference_mode():
        compiled_output = compiled_model(inputs)

for event in prof.events():
    if event.device_type == torch.profiler.DeviceType.CUDA and hasattr(event, "name"):
        print(f"  {event.name}")

compiled_mse = torch.mean((original_output - compiled_output) ** 2).item()
max_diff = torch.max(torch.abs(original_output - compiled_output)).item()
print(f"\ncompiled mse: {compiled_mse:.2e}, max diff: {max_diff:.2e}")

Running the above while monkey-patching F.scaled_dot_product_attention with the various attention implementations:

sdpa function: scaled_dot_product_attention
export mse: 0.00e+00, max diff: 0.00e+00
  __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
  __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_ResTraResRep_0xa6fbb435a519bc9882c94d2890c02b74
  __myl_ResTraSliNegSliConRepRepMulMulAdd_0x70c8aa5401751de45fba3a553adfc607
  __myl_ResTraSliNegSliConRepRepMulMulAddResRepResTra_0xb58d3f3ce414b4ca89146687fcb73b0f
  sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize64x32x64_stage5_warpsize2x2x1_tensor16x8x8_execute_kernel_trt
  __myl_MulMaxSubExpSumDivMul_0x69e0487bbf9a2f96d4c2b4b256d1f286
  sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
  __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
  __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 1.04e-04, max diff: 3.26e-01
sdpa function: _sdpa
export mse: 0.00e+00, max diff: 0.00e+00
  __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
  __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_ResTraResRep_0xa6fbb435a519bc9882c94d2890c02b74
  __myl_ResTraSliNegSliConRepRepMulMulAddResRepResTra_0xb58d3f3ce414b4ca89146687fcb73b0f
  __myl_ResTraSliNegSliConRepRepMulMulAdd_0x70c8aa5401751de45fba3a553adfc607
  sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize64x32x64_stage5_warpsize2x2x1_tensor16x8x8_execute_kernel_trt
  __myl_RepAddMaxSubExpSumDivMul_0x85f8de7438077f8edb4e4c2819d66f3f
  sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x32_cgasize1x1x1_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
  __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
  __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 1.04e-04, max diff: 3.26e-01
sdpa function: _uniform_attn
export mse: 0.00e+00, max diff: 0.00e+00
  __myl_MovMulMeaAddSqr_0x88159ac0c686f6ec9db477bbf7b4210b
  __myl_DivRepMulMul_0x3abe8c23c8a3f1deb0eb343034c5df68
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  __myl_TraResRep_0x8ee8320c2c0582a91dccdbb655981f23
  sm80_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize32x32x64_stage3_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  __myl_AddMulMea_0x6a8512bf16b2ea049a964503a77d5594
  __myl_AddSqrDivRepMulMul_0x379e100d7b406e5b5f69f30dab749489
  sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x256x32_warpgroupsize1x1x1_beta0_execute_segment_k_off_kernel_trt
  __myl_ResResNegExpAddDivMulMul_0x2b720048b807e670aa820340c81211d7
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_kernel_trt
  sm80_xmma_gemm_f32f32_tf32f32_f32_tn_n_tilesize64x64x64_stage4_warpsize2x1x2_tensor16x8x8_execute_split_k_kernel_trt
  __myl_Add_0xeb2b2d5cab2ae54f6f4c578e29318964

compiled mse: 3.09e-09, max diff: 2.43e-04
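
One more observation: the gemm kernels in all three profiles are tf32 variants (f32f32_tf32f32), so part of the residual error with _uniform_attn may simply be TF32 matmul precision. A quick pure-PyTorch way to gauge that contribution (a sketch, assuming an Ampere or newer GPU; this flips PyTorch's matmul mode, not TensorRT's builder flags):

import torch

# shapes loosely matching the repro (14 heads, head_dim 64, seq 120)
q = torch.randn(1, 14, 120, 64, device="cuda")
k = torch.randn(1, 14, 120, 64, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
ref = q @ k.transpose(-2, -1)       # full fp32 matmul

torch.backends.cuda.matmul.allow_tf32 = True
approx = q @ k.transpose(-2, -1)    # TF32 matmul

print(f"tf32 max diff: {(ref - approx).abs().max().item():.2e}")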
