Skip to content

Commit e2cd924

Browse files
py39 does not like | E TypeError: unsupported operand type(s) for |: 'type' and 'EnumMeta' (#3611)
1 parent b213c70 commit e2cd924

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import platform
66
from enum import Enum
7-
from typing import Any, Callable, List, Optional, Sequence, Set
7+
from typing import Any, Callable, List, Optional, Sequence, Set, Union
88

99
import torch
1010
import torch.fx
@@ -170,7 +170,7 @@ def compile(
170170
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
171171
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
172172
kwarg_inputs: Optional[dict[Any, Any]] = None,
173-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
173+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
174174
**kwargs: Any,
175175
) -> (
176176
torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
@@ -213,7 +213,7 @@ def compile(
213213
"""
214214

215215
input_list = inputs if inputs is not None else []
216-
enabled_precisions_set: Set[dtype | torch.dtype] = (
216+
enabled_precisions_set: Set[Union[torch.dtype, dtype]] = (
217217
enabled_precisions
218218
if enabled_precisions is not None
219219
else _defaults.ENABLED_PRECISIONS
@@ -308,7 +308,7 @@ def cross_compile_for_windows(
308308
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
309309
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
310310
kwarg_inputs: Optional[dict[Any, Any]] = None,
311-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
311+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
312312
**kwargs: Any,
313313
) -> None:
314314
"""Compile a PyTorch module using TensorRT in Linux for Inference in Windows
@@ -424,7 +424,7 @@ def convert_method_to_trt_engine(
424424
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
425425
kwarg_inputs: Optional[dict[Any, Any]] = None,
426426
ir: str = "default",
427-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
427+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
428428
**kwargs: Any,
429429
) -> bytes:
430430
"""Convert a TorchScript module method to a serialized TensorRT engine

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -993,9 +993,9 @@ def convert_exported_program_to_serialized_trt_engine(
993993
*,
994994
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
995995
kwarg_inputs: Optional[dict[Any, Any]] = None,
996-
enabled_precisions: (
997-
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
998-
) = _defaults.ENABLED_PRECISIONS,
996+
enabled_precisions: Union[
997+
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
998+
] = _defaults.ENABLED_PRECISIONS,
999999
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
10001000
workspace_size: int = _defaults.WORKSPACE_SIZE,
10011001
min_block_size: int = _defaults.MIN_BLOCK_SIZE,

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
strict: bool = True,
7070
allow_complex_guards_as_runtime_asserts: bool = False,
7171
weight_streaming_budget: Optional[int] = None,
72-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
72+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
7373
**kwargs: Any,
7474
) -> None:
7575
"""

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from copy import deepcopy
4-
from typing import Any, Dict, List, Optional, Set
4+
from typing import Any, Dict, List, Optional, Set, Union
55

6+
import tensorrt as trt
67
import torch
78
import torch_tensorrt._C.ts as _ts_C
89
from torch_tensorrt import _C
@@ -13,8 +14,6 @@
1314
from torch_tensorrt.ts._Input import TorchScriptInput
1415
from torch_tensorrt.ts.logging import Level, log
1516

16-
import tensorrt as trt
17-
1817

1918
def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
2019
clone = torch.classes.tensorrt._Input()
@@ -310,7 +309,7 @@ def TensorRTCompileSpec(
310309
device: Optional[torch.device | Device] = None,
311310
disable_tf32: bool = False,
312311
sparse_weights: bool = False,
313-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
312+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
314313
refit: bool = False,
315314
debug: bool = False,
316315
capability: EngineCapability = EngineCapability.STANDARD,

py/torch_tensorrt/ts/_compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Any, List, Optional, Sequence, Set, Tuple
4+
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
55

66
import torch
77
import torch_tensorrt._C.ts as _C
@@ -18,7 +18,7 @@ def compile(
1818
device: Device = Device._current_device(),
1919
disable_tf32: bool = False,
2020
sparse_weights: bool = False,
21-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
21+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
2222
refit: bool = False,
2323
debug: bool = False,
2424
capability: EngineCapability = EngineCapability.STANDARD,
@@ -172,7 +172,7 @@ def convert_method_to_trt_engine(
172172
device: Device = Device._current_device(),
173173
disable_tf32: bool = False,
174174
sparse_weights: bool = False,
175-
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
175+
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
176176
refit: bool = False,
177177
debug: bool = False,
178178
capability: EngineCapability = EngineCapability.STANDARD,

0 commit comments

Comments
 (0)