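Bug Description
Compiling a traced model with `torch_tensorrt.compile` and a dynamic-shape `torch_tensorrt.Input` fails when the model calls `F.interpolate` with a `size` taken from the input's shape. Conversion aborts with `Expected ivalue->isIntList() to be true but got false`.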
To Reproduce
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x1 = self.conv(x)
        # Upsample back to the input's spatial size (dynamic under a min/opt/max Input).
        x1 = F.interpolate(x1, size=(x.shape[2], x.shape[3]), mode="nearest")
        x = x1 + x
        return x


if __name__ == "__main__":
    with torch.no_grad():
        device = torch.device("cuda")
        net = Net().to(device)
        net = net.eval()
        dummy_input = torch.randn(1, 3, 64, 64).to(device)
        with torch.jit.optimized_execution(False):
            traced_net = torch.jit.trace(net, dummy_input)
            trt_net = torch_tensorrt.compile(
                traced_net,
                inputs=[
                    torch_tensorrt.Input(
                        min_shape=[1, 3, 16, 16],
                        opt_shape=[4, 3, 64, 64],
                        max_shape=[8, 3, 128, 128],
                    ),
                ],
                enabled_precisions={torch.half},
                truncate_long_and_double=True,
                allow_shape_tensors=True,
                workspace_size=1 << 30,  # 1 GB
            )
        print(net(dummy_input).shape)
        print(trt_net(dummy_input).shape)
```
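The residual add `x1 + x` only works if the upsampled tensor has exactly the input's height and width. With `kernel_size=3, stride=2, padding=1` the convolution produces `ceil(H / 2)` rows, so a fixed `scale_factor=2` would return `H + 1` for every odd `H` in the 16 to 128 range allowed by the `Input` spec. That is presumably why the target size is read from `x.shape` instead. The pure-Python sketch below checks that arithmetic with the standard `nn.Conv2d` output-size formula (dilation 1); the helper names `conv_out_size` and `check_range` are illustrative only and not part of any library.

```python
import math


def conv_out_size(h: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """nn.Conv2d output size with dilation=1: floor((h + 2p - k) / s) + 1."""
    return (h + 2 * padding - kernel) // stride + 1


def check_range(lo: int = 16, hi: int = 128) -> list:
    """Return the sizes in [lo, hi] for which upsampling the conv output
    by a fixed factor of 2 would NOT reproduce the original size."""
    mismatches = []
    for h in range(lo, hi + 1):
        down = conv_out_size(h)
        # With k=3, s=2, p=1 the convolution halves the size and rounds up.
        assert down == math.ceil(h / 2)
        if 2 * down != h:
            mismatches.append(h)
    return mismatches


if __name__ == "__main__":
    bad = check_range()
    print(f"{len(bad)} sizes in [16, 128] cannot use scale_factor=2, e.g. {bad[:5]}")
```

Every odd size in the range is a mismatch, so a constant scale factor is not a drop-in replacement for `size=(x.shape[2], x.shape[3])` in this model.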
Log:
```text
WARNING:torch_tensorrt._compile:Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript
Traceback (most recent call last):
  File "xxx.py", line 29, in <module>
    trt_net = torch_tensorrt.compile(
  File "/home/xxx/miniconda3/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 185, in compile
    compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
  File "/home/xxx/miniconda3/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 151, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at ./core/conversion/var/Var_inl.h:62] Expected ivalue->isIntList() to be true but got false
Requested unwrapping of arg IValue assuming it was N3c104ListIlEE however type is Any[]
```
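`N3c104ListIlEE` in the error above is the mangled name of `c10::List<long>`, i.e. the plain int list the converter expected for the interpolation size. To see what it is handed instead, the sketch below (assuming only `torch` on CPU; `traced_graph_text` is a hypothetical helper) traces the same model and prints the relevant TorchScript nodes. `torch.jit.trace` records `x.shape[2]` and `x.shape[3]` as `aten::size` calls, so the `size` argument of `upsample_nearest2d` is assembled by `prim::ListConstruct` from runtime values rather than being a constant list. With dynamic input shapes those values are presumably resolved as shape tensors instead of integers, which would be consistent with the `Any[]` type in the log; this is a reading of the error, not a confirmed root cause.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same model as in the reproduction script.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x1 = self.conv(x)
        x1 = F.interpolate(x1, size=(x.shape[2], x.shape[3]), mode="nearest")
        return x1 + x


def traced_graph_text() -> str:
    """Trace the model on CPU and return its TorchScript graph as text."""
    net = Net().eval()
    with torch.no_grad():
        traced = torch.jit.trace(net, torch.randn(1, 3, 64, 64))
    return str(traced.graph)


if __name__ == "__main__":
    graph = traced_graph_text()
    # Keep only the nodes that build and consume the interpolation size.
    for line in graph.splitlines():
        if any(op in line for op in ("aten::size", "prim::ListConstruct", "upsample")):
            print(line.strip())
```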
Expected behavior
The model should compile and run with dynamic input shapes (any shape between `min_shape` and `max_shape`).
Environment
Collecting environment information...
PyTorch version: 2.2.0.dev20230919+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.5
Libc version: glibc-2.31
Python version: 3.8.16 (default, Jan 17 2023, 23:13:24) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-83-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
Nvidia driver version: 535.86.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD Ryzen Threadripper 3960X 24-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2200.000
CPU max MHz: 3800.0000
CPU min MHz: 2200.0000
BogoMIPS: 7585.95
Virtualization: AMD-V
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 12 MiB
L3 cache: 128 MiB
NUMA node0 CPU(s): 0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==1.8.6
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20230919+cu121
[pip3] torch-tensorrt==2.2.0.dev20230919+cu121
[pip3] torchaudio==2.2.0.dev20230919+cu121
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.17.0.dev20230919+cu121
[pip3] triton==2.0.0
[conda] msgpack-numpy 0.4.8 pypi_0 pypi
[conda] numpy 1.23.5 pypi_0 pypi
[conda] pytorch-lightning 1.8.6 pypi_0 pypi
[conda] pytorch-triton 2.1.0+6e4932cda8 pypi_0 pypi
[conda] torch 2.2.0.dev20230919+cu121 pypi_0 pypi
[conda] torch-tensorrt 2.2.0.dev20230919+cu121 pypi_0 pypi
[conda] torchaudio 2.2.0.dev20230919+cu121 pypi_0 pypi
[conda] torchmetrics 1.1.2 pypi_0 pypi
[conda] torchvision 0.17.0.dev20230919+cu121 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi