Skip to content

Commit e67bcff

Browse files
authored
add vulkan-heap-block-size flag (huggingface#498)
1 parent 005ded3 commit e67bcff

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tqdm.auto import tqdm
66
import numpy as np
77
from stable_args import args
8-
from utils import get_shark_model
8+
from utils import get_shark_model, set_iree_runtime_flags
99
from opt_params import get_unet, get_vae, get_clip
1010
import time
1111

@@ -46,6 +46,7 @@ def end_profiling(device):
4646

4747
batch_size = len(prompt)
4848

49+
set_iree_runtime_flags()
4950
unet = get_unet()
5051
vae = get_vae()
5152
clip = get_clip()

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,10 @@
9797
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
9898
)
9999

100+
p.add_argument(
101+
"--vulkan_large_heap_block_size",
102+
default="4294967296",
103+
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
104+
)
105+
100106
args = p.parse_args()

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from shark.shark_inference import SharkInference
55
from stable_args import args
66
from shark.shark_importer import import_with_fx
7+
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
78

89

910
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -16,6 +17,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
1617
extended_name = "{}_{}".format(model_name, device)
1718
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
1819
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
20+
print(f"loading existing vmfb from: {vmfb_path}")
1921
shark_module.load_module(vmfb_path, extra_args=extra_args)
2022
else:
2123
if args.save_vmfb:
@@ -61,3 +63,14 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
6163
)
6264

6365
return _compile_module(shark_module, model_name, extra_args)
66+
67+
68+
def set_iree_runtime_flags():
    """Push SD-specific runtime flags into the IREE runtime.

    Currently only Vulkan is affected: when the selected device is a
    Vulkan device, forward the user-configurable VMA
    preferredLargeHeapBlockSize value (``--vulkan_large_heap_block_size``)
    to the IREE runtime flag parser.
    """
    if "vulkan" not in args.device:
        # Nothing to configure for non-Vulkan devices.
        return
    heap_flag = (
        f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}"
    )
    set_iree_vulkan_runtime_flags(flags=[heap_flag])

shark/iree_utils/vulkan_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from os import linesep
1818
from shark.iree_utils._common import run_cmd
19+
import iree.runtime as ireert
1920

2021

2122
def get_vulkan_device_name():
@@ -68,3 +69,9 @@ def get_iree_vulkan_args(extra_args=[]):
6869
if vulkan_triple_flag is not None:
6970
vulkan_flag.append(vulkan_triple_flag)
7071
return vulkan_flag
72+
73+
74+
def set_iree_vulkan_runtime_flags(flags):
    """Register Vulkan-related flags with the IREE runtime.

    Each entry in *flags* is a single ``--name=value`` style string and is
    handed to ``iree.runtime.flags.parse_flags`` one at a time.
    """
    for runtime_flag in flags:
        ireert.flags.parse_flags(runtime_flag)

0 commit comments

Comments (0)