File tree Expand file tree Collapse file tree 3 files changed +23
-0
lines changed Expand file tree Collapse file tree 3 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -76,6 +76,11 @@ def check_device_drivers(device):
76
76
return True
77
77
elif device == "cpu" :
78
78
return False
79
+ elif device == "rocm" :
80
+ try :
81
+ subprocess .check_output ("rocminfo" )
82
+ except Exception :
83
+ return True
79
84
# Unknown device.
80
85
else :
81
86
return True
@@ -89,5 +94,7 @@ def device_driver_info(device):
89
94
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
90
95
elif device in ["metal" , "vulkan" ]:
91
96
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
97
+ elif device == "rocm" :
98
+ return "rocm info not found. Please install rocm"
92
99
else :
93
100
return f"{ device } is not supported."
Original file line number Diff line number Diff line change @@ -31,6 +31,10 @@ def get_iree_device_args(device):
31
31
from shark .iree_utils .vulkan_utils import get_iree_vulkan_args
32
32
33
33
return get_iree_vulkan_args ()
34
+ if device == "rocm" :
35
+ from shark .iree_utils .gpu_utils import get_iree_rocm_args
36
+
37
+ return get_iree_rocm_args ()
34
38
return []
35
39
36
40
Original file line number Diff line number Diff line change @@ -35,6 +35,18 @@ def get_iree_gpu_args():
35
35
return ["--iree-hal-cuda-disable-loop-nounroll-wa" ]
36
36
37
37
38
+ # Get the default gpu args given the architecture.
39
+ def get_iree_rocm_args ():
40
+ ireert .flags .FUNCTION_INPUT_VALIDATION = False
41
+ # TODO: find a way to get arch from code.
42
+ rocm_arch = "gfx908"
43
+ return [
44
+ f"--iree-rocm-target-chip={ rocm_arch } " ,
45
+ "--iree-rocm-link-bc=true" ,
46
+ "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode" ,
47
+ ]
48
+
49
+
38
50
# Some constants taken from cuda.h
39
51
CUDA_SUCCESS = 0
40
52
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
You can’t perform that action at this time.
0 commit comments