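Status: Open

## Description

### Bug Description

`TestBF16Support.test_bf16_cpp` fails. The module is compiled in bfloat16 with `torch_tensorrt.dynamo.compile(..., use_python_runtime=False)`, i.e. the C++ runtime. Its output differs from the eager PyTorch output by a maximum absolute difference of `inf`, far outside the test's tolerance of 3e-2. The full pytest failure output follows.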
```
self = <dynamo.models.test_dtype_support.TestBF16Support testMethod=test_bf16_cpp>

    @unittest.skipIf(
        not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
        "Torch-TensorRT Runtime is not available",
    )
    def test_bf16_cpp(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
                self.relu = torch.nn.ReLU()
            def forward(self, x):
                out = self.conv(x)
                out = self.relu(out)
                return out
        in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
        mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)
        exp_mod = torch.export.export(mod, (in_tensor,))
        trt_mod = torch_tensorrt.dynamo.compile(
            exp_mod,
            inputs=[in_tensor],
            pass_through_build_failures=True,
            enabled_precisions={torch.float, torch.bfloat16, torch.half},
            min_block_size=1,
            use_python_runtime=False,
            cache_built_engines=False,
            reuse_cached_engines=False,
        )
        torch_model_results = mod(in_tensor)
        optimized_model_results = trt_mod(in_tensor)
        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
>       self.assertAlmostEqual(
            max_diff,
            0,
            delta=3e-2,
            msg=f"Torch outputs and TRT outputs don't match close enough.",
        )
E       AssertionError: inf != 0 within 0.03 delta (inf difference) : Torch outputs and TRT outputs don't match close enough.
E
E       To execute this test, run the following from the base repo dir:
E           python test_dtype_support.py TestBF16Support.test_bf16_cpp
E
E       This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

models/test_dtype_support.py:234: AssertionError
```
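## To Reproduce

Steps to reproduce the behavior:

1. Use a CUDA machine where the Torch-TensorRT C++ runtime is available (`torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime` is true); otherwise the test is skipped.
2. Run the failing test with the command printed by the test harness:
   `python test_dtype_support.py TestBF16Support.test_bf16_cpp`

## Expected behavior

The TensorRT module compiled with `use_python_runtime=False` should produce the same bfloat16 output as the eager PyTorch module, within the test's tolerance (maximum absolute difference ≤ 3e-2). Instead, the maximum absolute difference is `inf`.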
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information: