Skip to content

Commit fc8aa6a

Browse files
authored
Add ROCM parameters (huggingface#335)
1 parent 9bd951b commit fc8aa6a

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

shark/iree_utils/_common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ def check_device_drivers(device):
7676
return True
7777
elif device == "cpu":
7878
return False
79+
elif device == "rocm":
80+
try:
81+
subprocess.check_output("rocminfo")
82+
except Exception:
83+
return True
7984
# Unknown device.
8085
else:
8186
return True
@@ -89,5 +94,7 @@ def device_driver_info(device):
8994
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
9095
elif device in ["metal", "vulkan"]:
9196
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
97+
elif device == "rocm":
98+
return "rocm info not found. Please install rocm"
9299
else:
93100
return f"{device} is not supported."

shark/iree_utils/compile_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def get_iree_device_args(device):
3131
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
3232

3333
return get_iree_vulkan_args()
34+
if device == "rocm":
35+
from shark.iree_utils.gpu_utils import get_iree_rocm_args
36+
37+
return get_iree_rocm_args()
3438
return []
3539

3640

shark/iree_utils/gpu_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ def get_iree_gpu_args():
3535
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
3636

3737

38+
# Get the default gpu args given the architecture.
39+
def get_iree_rocm_args():
40+
ireert.flags.FUNCTION_INPUT_VALIDATION = False
41+
# TODO: find a way to get arch from code.
42+
rocm_arch = "gfx908"
43+
return [
44+
f"--iree-rocm-target-chip={rocm_arch}",
45+
"--iree-rocm-link-bc=true",
46+
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
47+
]
48+
49+
3850
# Some constants taken from cuda.h
3951
CUDA_SUCCESS = 0
4052
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16

0 commit comments

Comments
 (0)