Skip to content

Commit 934f15e

Browse files
authored
Fix IREE eager backend device string (huggingface#237)
1 parent 38664a4 commit 934f15e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

shark/iree_eager_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
4848

4949
def __init__(self, device: str):
5050
self.torch_device_str = device
51-
self.iree_device_str = IREE_DEVICE_MAP[device]
52-
self.config = ireert.Config(self.iree_device_str)
51+
self.config = ireert.Config(IREE_DEVICE_MAP[device])
52+
self.raw_device_str = device
5353

5454
def get_torch_metadata(
5555
self, tensor: DeviceArray, kwargs: Dict[str, Any]
@@ -71,7 +71,7 @@ def compile(self, imported_module: Module):
7171
"EagerMode",
7272
)
7373
callable, _ = get_iree_compiled_module(
74-
imported_module, self.iree_device_str, func_name=fn_name
74+
imported_module, self.raw_device_str, func_name=fn_name
7575
)
7676
return callable
7777

0 commit comments

Comments
 (0)