Skip to content

Commit 29f59f9

Browse files
fix cross compilation test bug (#3609)
1 parent e2cd924 commit 29f59f9

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

py/torch_tensorrt/_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ def check_cross_compile_trt_win_lib() -> bool:
1616
# cross compile feature is only available on linux
1717
# build engine on linux and run on windows
1818
if sys.platform.startswith("linux"):
19+
import re
20+
1921
import dllist
2022

2123
loaded_libs = dllist.dllist()
22-
target_lib = "libnvinfer_builder_resource_win.so.*"
23-
if target_lib in loaded_libs:
24-
return True
24+
target_lib = ".*libnvinfer_builder_resource_win.so.*"
25+
return any(re.match(target_lib, lib) for lib in loaded_libs)
2526
return False

tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import torch
88
import torch_tensorrt
99
from torch.testing._internal.common_utils import TestCase
10-
from torch_tensorrt.dynamo.utils import get_model_device
1110
from torch_tensorrt._utils import check_cross_compile_trt_win_lib
11+
from torch_tensorrt.dynamo.utils import get_model_device
1212

1313
from ..testing_utilities import DECIMALS_OF_AGREEMENT
1414

@@ -85,8 +85,12 @@ def forward(self, a, b):
8585
@pytest.mark.unit
8686
def test_dynamo_cross_compile_for_windows_cpu_offload(self):
8787
class Add(torch.nn.Module):
88+
def __init__(self):
89+
super().__init__()
90+
self.linear = torch.nn.Linear(3, 3)
91+
8892
def forward(self, a, b):
89-
return torch.add(a, b)
93+
return torch.add(self.linear(a), b)
9094

9195
model = Add().eval().cuda()
9296
inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())
@@ -101,7 +105,7 @@ def forward(self, a, b):
101105
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
102106
exp_program, **compile_spec
103107
)
104-
assert get_model_device(trt_gm).type == "cpu"
108+
assert get_model_device(model).type == "cpu"
105109
torch_tensorrt.dynamo.save_cross_compiled_exported_program(
106110
trt_gm, file_path=trt_ep_path
107111
)
@@ -112,6 +116,10 @@ def forward(self, a, b):
112116
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
113117
"Cross compile for windows can only be enabled on linux x86-64 platform",
114118
)
119+
@unittest.skipIf(
120+
not (check_cross_compile_trt_win_lib()),
121+
"TRT windows lib for cross compile not found",
122+
)
115123
@pytest.mark.unit
116124
def test_dynamo_cross_compile_for_windows_multiple_output(self):
117125
class Add(torch.nn.Module):

0 commit comments

Comments
 (0)