import sys
from pathlib import Path
from shark.parser import shark_args
- from google.cloud import storage


- def download_public_file(full_gs_url, destination_file_name):
-     """Downloads a public blob from the bucket."""
-     # bucket_name = "gs://your-bucket-name/path/to/file"
-     # destination_file_name = "local/path/to/file"
-
-     storage_client = storage.Client.create_anonymous_client()
-     bucket_name = full_gs_url.split("/")[2]
-     source_blob_name = "/".join(full_gs_url.split("/")[3:])
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(source_blob_name)
-     blob.download_to_filename(destination_file_name)
+ def resource_path(relative_path):
+     """Get absolute path to resource, works for dev and for PyInstaller"""
+     base_path = getattr(
+         sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
+     )
+     return os.path.join(base_path, relative_path)


+ GSUTIL_PATH = resource_path("gsutil")

GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '

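For context on the new helpers: when the package is frozen with PyInstaller, sys._MEIPASS points at the bundle's temporary extraction directory, so GSUTIL_PATH resolves to a gsutil binary shipped inside the bundle; in a plain checkout it falls back to a gsutil sitting next to this file. The download functions below then build an ordinary shell command from these pieces. A minimal sketch of the resulting string, using placeholder values for GSUTIL_PATH and WORKDIR (both are computed elsewhere in the module) and "resnet50" purely as an example model name:

    # Sketch only: placeholder paths, not the values the module actually computes.
    GSUTIL_PATH = "/tmp/_MEIxxxxxx/gsutil"  # what resource_path("gsutil") might return when frozen
    GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
    WORKDIR = "/home/user/.local/shark_tank"  # hypothetical local tank directory
    tank_url = "gs://shark_tank/latest"
    model_dir_name = "resnet50_torch"
    gs_command = (
        GSUTIL_PATH + GSUTIL_FLAGS + tank_url + "/" + model_dir_name + ' "' + WORKDIR + '"'
    )
    print(gs_command)
    # /tmp/_MEIxxxxxx/gsutil -o "GSUtil:parallel_process_count=1" -m cp -r gs://shark_tank/latest/resnet50_torch "/home/user/.local/shark_tank"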
@@ -103,23 +98,103 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):


# Downloads the torch model from gs://shark_tank dir.
- def download_model(
-     model_name,
-     dynamic=False,
-     tank_url="gs://shark_tank/latest",
-     frontend=None,
-     tuned=None,
+ def download_torch_model(
+     model_name, dynamic=False, tank_url="gs://shark_tank/latest"
):
    model_name = model_name.replace("/", "_")
    dyn_str = "_dynamic" if dynamic else ""
    os.makedirs(WORKDIR, exist_ok=True)
-     model_dir_name = model_name + "_" + frontend
-     full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
+     model_dir_name = model_name + "_torch"
+
+     def gs_download_model():
+         gs_command = (
+             GSUTIL_PATH
+             + GSUTIL_FLAGS
+             + tank_url
+             + "/"
+             + model_dir_name
+             + ' "'
+             + WORKDIR
+             + '"'
+         )
+         if os.system(gs_command) != 0:
+             raise Exception("model not present in the tank. Contact Nod Admin")
+
+     if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
+         gs_download_model()
+     else:
+         if not _internet_connected():
+             print(
+                 "No internet connection. Using the model already present in the tank."
+             )
+         else:
+             model_dir = os.path.join(WORKDIR, model_dir_name)
+             local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
+             gs_hash = (
+                 GSUTIL_PATH
+                 + GSUTIL_FLAGS
+                 + tank_url
+                 + "/"
+                 + model_dir_name
+                 + "/hash.npy"
+                 + " "
+                 + os.path.join(model_dir, "upstream_hash.npy")
+             )
+             if os.system(gs_hash) != 0:
+                 raise Exception("hash of the model not present in the tank.")
+             upstream_hash = str(
+                 np.load(os.path.join(model_dir, "upstream_hash.npy"))
+             )
+             if local_hash != upstream_hash:
+                 if shark_args.update_tank == True:
+                     gs_download_model()
+                 else:
+                     print(
+                         "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
+                     )
+
+     model_dir = os.path.join(WORKDIR, model_dir_name)
+     with open(
+         os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
+         mode="rb",
+     ) as f:
+         mlir_file = f.read()
+
+     function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
+     inputs = np.load(os.path.join(model_dir, "inputs.npz"))
+     golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
+
+     inputs_tuple = tuple([inputs[key] for key in inputs])
+     golden_out_tuple = tuple([golden_out[key] for key in golden_out])
+     return mlir_file, function_name, inputs_tuple, golden_out_tuple
+
+
+ # Downloads the tflite model from gs://shark_tank dir.
+ def download_tflite_model(
+     model_name, dynamic=False, tank_url="gs://shark_tank/latest"
+ ):
+     dyn_str = "_dynamic" if dynamic else ""
+     os.makedirs(WORKDIR, exist_ok=True)
+     model_dir_name = model_name + "_tflite"
+
+     def gs_download_model():
+         gs_command = (
+             GSUTIL_PATH
+             + GSUTIL_FLAGS
+             + tank_url
+             + "/"
+             + model_dir_name
+             + ' "'
+             + WORKDIR
+             + '"'
+         )
+         if os.system(gs_command) != 0:
+             raise Exception("model not present in the tank. Contact Nod Admin")

    if not check_dir_exists(
-         model_dir_name, frontend=frontend, dynamic=dyn_str
+         model_dir_name, frontend="tflite", dynamic=dyn_str
    ):
-         download_public_file(full_gs_url, WORKDIR)
+         gs_download_model()
    else:
        if not _internet_connected():
            print(
@@ -128,34 +203,104 @@ def download_model(
        else:
            model_dir = os.path.join(WORKDIR, model_dir_name)
            local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
-             gs_hash_url = (
-                 tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
+             gs_hash = (
+                 GSUTIL_PATH
+                 + GSUTIL_FLAGS
+                 + tank_url
+                 + "/"
+                 + model_dir_name
+                 + "/hash.npy"
+                 + " "
+                 + os.path.join(model_dir, "upstream_hash.npy")
            )
-             download_public_file(
-                 gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
+             if os.system(gs_hash) != 0:
+                 raise Exception("hash of the model not present in the tank.")
+             upstream_hash = str(
+                 np.load(os.path.join(model_dir, "upstream_hash.npy"))
            )
+             if local_hash != upstream_hash:
+                 if shark_args.update_tank == True:
+                     gs_download_model()
+                 else:
+                     print(
+                         "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
+                     )
+
+     model_dir = os.path.join(WORKDIR, model_dir_name)
+     with open(
+         os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
+         mode="rb",
+     ) as f:
+         mlir_file = f.read()
+
+     function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
+     inputs = np.load(os.path.join(model_dir, "inputs.npz"))
+     golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
+
+     inputs_tuple = tuple([inputs[key] for key in inputs])
+     golden_out_tuple = tuple([golden_out[key] for key in golden_out])
+     return mlir_file, function_name, inputs_tuple, golden_out_tuple
+
+
+ def download_tf_model(
+     model_name, tuned=None, tank_url="gs://shark_tank/latest"
+ ):
+     model_name = model_name.replace("/", "_")
+     os.makedirs(WORKDIR, exist_ok=True)
+     model_dir_name = model_name + "_tf"
+
+     def gs_download_model():
+         gs_command = (
+             GSUTIL_PATH
+             + GSUTIL_FLAGS
+             + tank_url
+             + "/"
+             + model_dir_name
+             + ' "'
+             + WORKDIR
+             + '"'
+         )
+         if os.system(gs_command) != 0:
+             raise Exception("model not present in the tank. Contact Nod Admin")
+
+     if not check_dir_exists(model_dir_name, frontend="tf"):
+         gs_download_model()
+     else:
+         if not _internet_connected():
+             print(
+                 "No internet connection. Using the model already present in the tank."
+             )
+         else:
+             model_dir = os.path.join(WORKDIR, model_dir_name)
+             local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
+             gs_hash = (
+                 GSUTIL_PATH
+                 + GSUTIL_FLAGS
+                 + tank_url
+                 + "/"
+                 + model_dir_name
+                 + "/hash.npy"
+                 + " "
+                 + os.path.join(model_dir, "upstream_hash.npy")
+             )
+             if os.system(gs_hash) != 0:
+                 raise Exception("hash of the model not present in the tank.")
            upstream_hash = str(
                np.load(os.path.join(model_dir, "upstream_hash.npy"))
            )
            if local_hash != upstream_hash:
                if shark_args.update_tank == True:
-                     download_public_file(full_gs_url, WORKDIR)
+                     gs_download_model()
                else:
                    print(
                        "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
                    )

    model_dir = os.path.join(WORKDIR, model_dir_name)
-     suffix = (
-         "_" + frontend + ".mlir"
-         if tuned is None
-         else "_" + frontend + "_" + tuned + ".mlir"
-     )
+     suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
    filename = os.path.join(model_dir, model_name + suffix)
    if not os.path.isfile(filename):
-         filename = os.path.join(
-             model_dir, model_name + "_" + frontend + ".mlir"
-         )
+         filename = os.path.join(model_dir, model_name + "_tf.mlir")

    with open(filename, mode="rb") as f:
        mlir_file = f.read()
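For reference, the new torch entry point can be exercised roughly as follows (a minimal sketch: the import path and the "resnet50" model name are assumptions, and the call needs either network access to gs://shark_tank or a previously populated local tank). download_tflite_model and download_tf_model follow the same pattern for the other frontends.

    # Sketch only; module path and model name are assumed.
    from shark.shark_downloader import download_torch_model

    mlir_file, function_name, inputs, golden_out = download_torch_model(
        "resnet50", dynamic=False
    )
    # mlir_file is raw bytes; inputs and golden_out are tuples of numpy arrays.
    print(function_name, [arr.shape for arr in inputs])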