Skip to content

Commit 2361906

Browse files
authored
Disable passing of sm_arch to iree-compile CL args by default. (huggingface#253)
* Disable passing of sm_arch to iree-compile CL args by default.
* Fix formatting.
1 parent f7f24dc commit 2361906

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

shark/iree_utils/gpu_utils.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,17 @@
1616

1717
import iree.runtime as ireert
1818
import ctypes
19+
from shark.parser import shark_args
1920

2021
# Get the default gpu args given the architecture.
2122
def get_iree_gpu_args():
2223
ireert.flags.FUNCTION_INPUT_VALIDATION = False
2324
ireert.flags.parse_flags("--cuda_allow_inline_execution")
2425
# TODO: Give the user_interface to pass the sm_arch.
2526
sm_arch = get_cuda_sm_cc()
26-
if sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]:
27+
if (
28+
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
29+
) and (shark_args.enable_tf32 == True):
2730
return [
2831
"--iree-hal-cuda-disable-loop-nounroll-wa",
2932
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",

shark/parser.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -47,16 +47,10 @@ def dir_file(path):
4747
default="./shark_tmp",
4848
)
4949
parser.add_argument(
50-
"--save_mlir",
50+
"--enable_tf32",
5151
default=False,
5252
action="store_true",
53-
help="Saves input MLIR module to /tmp/ directory.",
54-
)
55-
parser.add_argument(
56-
"--save_vmfb",
57-
default=False,
58-
action="store_true",
59-
help="Saves iree .vmfb module to /tmp/ directory.",
53+
help="Enables TF32 precision calculations on supported GPUs.",
6054
)
6155
parser.add_argument(
6256
"--model_config_path",

0 commit comments

Comments
 (0)