7
7
import torch
8
8
import torch_tensorrt
9
9
from torch .testing ._internal .common_utils import TestCase
10
- from torch_tensorrt .dynamo .utils import get_model_device
11
10
from torch_tensorrt ._utils import check_cross_compile_trt_win_lib
11
+ from torch_tensorrt .dynamo .utils import get_model_device
12
12
13
13
from ..testing_utilities import DECIMALS_OF_AGREEMENT
14
14
@@ -85,8 +85,12 @@ def forward(self, a, b):
85
85
@pytest .mark .unit
86
86
def test_dynamo_cross_compile_for_windows_cpu_offload (self ):
87
87
class Add (torch .nn .Module ):
88
+ def __init__ (self ):
89
+ super ().__init__ ()
90
+ self .linear = torch .nn .Linear (3 , 3 )
91
+
88
92
def forward (self , a , b ):
89
- return torch .add (a , b )
93
+ return torch .add (self . linear ( a ) , b )
90
94
91
95
model = Add ().eval ().cuda ()
92
96
inputs = (torch .randn (2 , 3 ).cuda (), torch .randn (2 , 3 ).cuda ())
@@ -101,7 +105,7 @@ def forward(self, a, b):
101
105
trt_gm = torch_tensorrt .dynamo .cross_compile_for_windows (
102
106
exp_program , ** compile_spec
103
107
)
104
- assert get_model_device (trt_gm ).type == "cpu"
108
+ assert get_model_device (model ).type == "cpu"
105
109
torch_tensorrt .dynamo .save_cross_compiled_exported_program (
106
110
trt_gm , file_path = trt_ep_path
107
111
)
@@ -112,6 +116,10 @@ def forward(self, a, b):
112
116
platform .system () != "Linux" or platform .architecture ()[0 ] != "64bit" ,
113
117
"Cross compile for windows can only be enabled on linux x86-64 platform" ,
114
118
)
119
+ @unittest .skipIf (
120
+ not (check_cross_compile_trt_win_lib ()),
121
+ "TRT windows lib for cross compile not found" ,
122
+ )
115
123
@pytest .mark .unit
116
124
def test_dynamo_cross_compile_for_windows_multiple_output (self ):
117
125
class Add (torch .nn .Module ):
0 commit comments