Skip to content

Commit 3c46021

Browse files
authored
Revert "find gsutil on linux (huggingface#557)" (huggingface#560)
This reverts commit bba8646.
1 parent bba8646 commit 3c46021

File tree

20 files changed: +249 -98 lines changed

20 files changed: +249 -98 lines changed

cpp/save_img.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import tensorflow as tf
33
from shark.shark_inference import SharkInference
4+
from shark.shark_downloader import download_tf_model
45

56

67
def load_and_preprocess_image(fname: str):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ pyinstaller
66
tqdm
77

88
# SHARK Downloader
9-
google-cloud-storage
9+
gsutil
1010

1111
# Testing
1212
pytest

shark/examples/shark_inference/bloom_tank.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_model
2+
from shark.shark_downloader import download_torch_model
33

4-
mlir_model, func_name, inputs, golden_out = download_model(
5-
"bloom", frontend="torch"
6-
)
4+
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
75

86
shark_module = SharkInference(
97
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"

shark/examples/shark_inference/minilm_jit.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_model
2+
from shark.shark_downloader import download_torch_model
33

44

5-
mlir_model, func_name, inputs, golden_out = download_model(
6-
"microsoft/MiniLM-L12-H384-uncased",
7-
frontend="torch",
5+
mlir_model, func_name, inputs, golden_out = download_torch_model(
6+
"microsoft/MiniLM-L12-H384-uncased"
87
)
98

109

shark/examples/shark_inference/resnet50_script.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from torchvision import transforms
66
import sys
77
from shark.shark_inference import SharkInference
8-
from shark.shark_downloader import download_model
8+
from shark.shark_downloader import download_torch_model
99

1010

1111
################################## Preprocessing inputs and model ############
@@ -66,9 +66,7 @@ def forward(self, img):
6666

6767

6868
## Can pass any img or input to the forward module.
69-
mlir_model, func_name, inputs, golden_out = download_model(
70-
"resnet50", frontend="torch"
71-
)
69+
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
7270

7371
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
7472
shark_module.compile()

shark/examples/shark_inference/stable_diff_f16.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,12 +37,10 @@
3737

3838

3939
def fp16_unet():
40-
from shark.shark_downloader import download_model
40+
from shark.shark_downloader import download_torch_model
4141

42-
mlir_model, func_name, inputs, golden_out = download_model(
43-
"stable_diff_f16_18_OCT",
44-
tank_url="gs://shark_tank/prashant_nod",
45-
frontend="torch",
42+
mlir_model, func_name, inputs, golden_out = download_torch_model(
43+
"stable_diff_f16_18_OCT", tank_url="gs://shark_tank/prashant_nod"
4644
)
4745
shark_module = SharkInference(
4846
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

shark/examples/shark_inference/stable_diff_tf.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
from shark.shark_inference import SharkInference
20-
from shark.shark_downloader import download_model
20+
from shark.shark_downloader import download_tf_model
2121
from PIL import Image
2222

2323
# pip install "git+https://github.com/keras-team/keras-cv.git"
@@ -75,8 +75,8 @@ def __init__(self, device="cpu", jit_compile=True):
7575
# Create models
7676
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
7777

78-
mlir_model, func_name, inputs, golden_out = download_model(
79-
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
78+
mlir_model, func_name, inputs, golden_out = download_tf_model(
79+
"stable_diff", tank_url="gs://shark_tank/quinn"
8080
)
8181
shark_module = SharkInference(
8282
mlir_model, func_name, device=device, mlir_dialect="mhlo"

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,10 @@ def _compile_module(shark_module, model_name, extra_args=[]):
3939

4040
# Downloads the model from shark_tank and returns the shark_module.
4141
def get_shark_model(tank_url, model_name, extra_args=[]):
42-
from shark.shark_downloader import download_model
42+
from shark.shark_downloader import download_torch_model
4343

44-
mlir_model, func_name, inputs, golden_out = download_model(
45-
model_name,
46-
tank_url=tank_url,
47-
frontend="torch",
44+
mlir_model, func_name, inputs, golden_out = download_torch_model(
45+
model_name, tank_url=tank_url
4846
)
4947
shark_module = SharkInference(
5048
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

shark/examples/shark_inference/v_diffusion.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from shark.shark_inference import SharkInference
2-
from shark.shark_downloader import download_model
2+
from shark.shark_downloader import download_torch_model
33

44

5-
mlir_model, func_name, inputs, golden_out = download_model(
6-
"v_diffusion", frontend="torch"
7-
)
5+
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
86

97
shark_module = SharkInference(
108
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"

shark/shark_downloader.py

Lines changed: 180 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -17,22 +17,17 @@
1717
import sys
1818
from pathlib import Path
1919
from shark.parser import shark_args
20-
from google.cloud import storage
2120

2221

23-
def download_public_file(full_gs_url, destination_file_name):
24-
"""Downloads a public blob from the bucket."""
25-
# bucket_name = "gs://your-bucket-name/path/to/file"
26-
# destination_file_name = "local/path/to/file"
27-
28-
storage_client = storage.Client.create_anonymous_client()
29-
bucket_name = full_gs_url.split("/")[2]
30-
source_blob_name = "/".join(full_gs_url.split("/")[3:])
31-
bucket = storage_client.bucket(bucket_name)
32-
blob = bucket.blob(source_blob_name)
33-
blob.download_to_filename(destination_file_name)
22+
def resource_path(relative_path):
23+
"""Get absolute path to resource, works for dev and for PyInstaller"""
24+
base_path = getattr(
25+
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
26+
)
27+
return os.path.join(base_path, relative_path)
3428

3529

30+
GSUTIL_PATH = resource_path("gsutil")
3631
GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
3732

3833

@@ -103,23 +98,103 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
10398

10499

105100
# Downloads the torch model from gs://shark_tank dir.
106-
def download_model(
107-
model_name,
108-
dynamic=False,
109-
tank_url="gs://shark_tank/latest",
110-
frontend=None,
111-
tuned=None,
101+
def download_torch_model(
102+
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
112103
):
113104
model_name = model_name.replace("/", "_")
114105
dyn_str = "_dynamic" if dynamic else ""
115106
os.makedirs(WORKDIR, exist_ok=True)
116-
model_dir_name = model_name + "_" + frontend
117-
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
107+
model_dir_name = model_name + "_torch"
108+
109+
def gs_download_model():
110+
gs_command = (
111+
GSUTIL_PATH
112+
+ GSUTIL_FLAGS
113+
+ tank_url
114+
+ "/"
115+
+ model_dir_name
116+
+ ' "'
117+
+ WORKDIR
118+
+ '"'
119+
)
120+
if os.system(gs_command) != 0:
121+
raise Exception("model not present in the tank. Contact Nod Admin")
122+
123+
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
124+
gs_download_model()
125+
else:
126+
if not _internet_connected():
127+
print(
128+
"No internet connection. Using the model already present in the tank."
129+
)
130+
else:
131+
model_dir = os.path.join(WORKDIR, model_dir_name)
132+
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
133+
gs_hash = (
134+
GSUTIL_PATH
135+
+ GSUTIL_FLAGS
136+
+ tank_url
137+
+ "/"
138+
+ model_dir_name
139+
+ "/hash.npy"
140+
+ " "
141+
+ os.path.join(model_dir, "upstream_hash.npy")
142+
)
143+
if os.system(gs_hash) != 0:
144+
raise Exception("hash of the model not present in the tank.")
145+
upstream_hash = str(
146+
np.load(os.path.join(model_dir, "upstream_hash.npy"))
147+
)
148+
if local_hash != upstream_hash:
149+
if shark_args.update_tank == True:
150+
gs_download_model()
151+
else:
152+
print(
153+
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
154+
)
155+
156+
model_dir = os.path.join(WORKDIR, model_dir_name)
157+
with open(
158+
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
159+
mode="rb",
160+
) as f:
161+
mlir_file = f.read()
162+
163+
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
164+
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
165+
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
166+
167+
inputs_tuple = tuple([inputs[key] for key in inputs])
168+
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
169+
return mlir_file, function_name, inputs_tuple, golden_out_tuple
170+
171+
172+
# Downloads the tflite model from gs://shark_tank dir.
173+
def download_tflite_model(
174+
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
175+
):
176+
dyn_str = "_dynamic" if dynamic else ""
177+
os.makedirs(WORKDIR, exist_ok=True)
178+
model_dir_name = model_name + "_tflite"
179+
180+
def gs_download_model():
181+
gs_command = (
182+
GSUTIL_PATH
183+
+ GSUTIL_FLAGS
184+
+ tank_url
185+
+ "/"
186+
+ model_dir_name
187+
+ ' "'
188+
+ WORKDIR
189+
+ '"'
190+
)
191+
if os.system(gs_command) != 0:
192+
raise Exception("model not present in the tank. Contact Nod Admin")
118193

119194
if not check_dir_exists(
120-
model_dir_name, frontend=frontend, dynamic=dyn_str
195+
model_dir_name, frontend="tflite", dynamic=dyn_str
121196
):
122-
download_public_file(full_gs_url, WORKDIR)
197+
gs_download_model()
123198
else:
124199
if not _internet_connected():
125200
print(
@@ -128,34 +203,104 @@ def download_model(
128203
else:
129204
model_dir = os.path.join(WORKDIR, model_dir_name)
130205
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
131-
gs_hash_url = (
132-
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
206+
gs_hash = (
207+
GSUTIL_PATH
208+
+ GSUTIL_FLAGS
209+
+ tank_url
210+
+ "/"
211+
+ model_dir_name
212+
+ "/hash.npy"
213+
+ " "
214+
+ os.path.join(model_dir, "upstream_hash.npy")
133215
)
134-
download_public_file(
135-
gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
216+
if os.system(gs_hash) != 0:
217+
raise Exception("hash of the model not present in the tank.")
218+
upstream_hash = str(
219+
np.load(os.path.join(model_dir, "upstream_hash.npy"))
136220
)
221+
if local_hash != upstream_hash:
222+
if shark_args.update_tank == True:
223+
gs_download_model()
224+
else:
225+
print(
226+
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
227+
)
228+
229+
model_dir = os.path.join(WORKDIR, model_dir_name)
230+
with open(
231+
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
232+
mode="rb",
233+
) as f:
234+
mlir_file = f.read()
235+
236+
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
237+
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
238+
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
239+
240+
inputs_tuple = tuple([inputs[key] for key in inputs])
241+
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
242+
return mlir_file, function_name, inputs_tuple, golden_out_tuple
243+
244+
245+
def download_tf_model(
246+
model_name, tuned=None, tank_url="gs://shark_tank/latest"
247+
):
248+
model_name = model_name.replace("/", "_")
249+
os.makedirs(WORKDIR, exist_ok=True)
250+
model_dir_name = model_name + "_tf"
251+
252+
def gs_download_model():
253+
gs_command = (
254+
GSUTIL_PATH
255+
+ GSUTIL_FLAGS
256+
+ tank_url
257+
+ "/"
258+
+ model_dir_name
259+
+ ' "'
260+
+ WORKDIR
261+
+ '"'
262+
)
263+
if os.system(gs_command) != 0:
264+
raise Exception("model not present in the tank. Contact Nod Admin")
265+
266+
if not check_dir_exists(model_dir_name, frontend="tf"):
267+
gs_download_model()
268+
else:
269+
if not _internet_connected():
270+
print(
271+
"No internet connection. Using the model already present in the tank."
272+
)
273+
else:
274+
model_dir = os.path.join(WORKDIR, model_dir_name)
275+
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
276+
gs_hash = (
277+
GSUTIL_PATH
278+
+ GSUTIL_FLAGS
279+
+ tank_url
280+
+ "/"
281+
+ model_dir_name
282+
+ "/hash.npy"
283+
+ " "
284+
+ os.path.join(model_dir, "upstream_hash.npy")
285+
)
286+
if os.system(gs_hash) != 0:
287+
raise Exception("hash of the model not present in the tank.")
137288
upstream_hash = str(
138289
np.load(os.path.join(model_dir, "upstream_hash.npy"))
139290
)
140291
if local_hash != upstream_hash:
141292
if shark_args.update_tank == True:
142-
download_public_file(full_gs_url, WORKDIR)
293+
gs_download_model()
143294
else:
144295
print(
145296
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
146297
)
147298

148299
model_dir = os.path.join(WORKDIR, model_dir_name)
149-
suffix = (
150-
"_" + frontend + ".mlir"
151-
if tuned is None
152-
else "_" + frontend + "_" + tuned + ".mlir"
153-
)
300+
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
154301
filename = os.path.join(model_dir, model_name + suffix)
155302
if not os.path.isfile(filename):
156-
filename = os.path.join(
157-
model_dir, model_name + "_" + frontend + ".mlir"
158-
)
303+
filename = os.path.join(model_dir, model_name + "_tf.mlir")
159304

160305
with open(filename, mode="rb") as f:
161306
mlir_file = f.read()

0 commit comments

Comments
 (0)