35
35
WORKDIR = os .path .join (home , ".local/shark_tank/" )
36
36
print (WORKDIR )
37
37
38
+
38
39
# Checks whether the directory and files exists.
39
40
def check_dir_exists (model_name , frontend = "torch" , dynamic = "" ):
40
41
model_dir = os .path .join (WORKDIR , model_name )
41
42
42
43
# Remove the _tf keyword from end.
43
44
if frontend in ["tf" , "tensorflow" ]:
44
45
model_name = model_name [:- 3 ]
46
+ elif frontend in ["tflite" ]:
47
+ model_name = model_name [:- 7 ]
48
+ elif frontend in ["torch" , "pytorch" ]:
49
+ model_name = model_name [:- 6 ]
45
50
46
51
if os .path .isdir (model_dir ):
47
52
if (
48
53
os .path .isfile (
49
- os .path .join (model_dir , model_name + dynamic + ".mlir" )
54
+ os .path .join (
55
+ model_dir ,
56
+ model_name + dynamic + "_" + str (frontend ) + ".mlir" ,
57
+ )
50
58
)
51
59
and os .path .isfile (os .path .join (model_dir , "function_name.npy" ))
52
60
and os .path .isfile (os .path .join (model_dir , "inputs.npz" ))
@@ -65,27 +73,28 @@ def download_torch_model(model_name, dynamic=False):
65
73
model_name = model_name .replace ("/" , "_" )
66
74
dyn_str = "_dynamic" if dynamic else ""
67
75
os .makedirs (WORKDIR , exist_ok = True )
76
+ model_dir_name = model_name + "_torch"
68
77
69
78
def gs_download_model ():
70
79
gs_command = (
71
80
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
72
81
+ "/"
73
- + model_name
82
+ + model_dir_name
74
83
+ " "
75
84
+ WORKDIR
76
85
)
77
86
if os .system (gs_command ) != 0 :
78
87
raise Exception ("model not present in the tank. Contact Nod Admin" )
79
88
80
- if not check_dir_exists (model_name , dyn_str ):
89
+ if not check_dir_exists (model_dir_name , frontend = "torch" , dynamic = dyn_str ):
81
90
gs_download_model ()
82
91
else :
83
- model_dir = os .path .join (WORKDIR , model_name )
92
+ model_dir = os .path .join (WORKDIR , model_dir_name )
84
93
local_hash = str (np .load (os .path .join (model_dir , "hash.npy" )))
85
94
gs_hash = (
86
95
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
87
96
+ "/"
88
- + model_name
97
+ + model_dir_name
89
98
+ "/hash.npy"
90
99
+ " "
91
100
+ os .path .join (model_dir , "upstream_hash.npy" )
@@ -98,8 +107,10 @@ def gs_download_model():
98
107
if local_hash != upstream_hash :
99
108
gs_download_model ()
100
109
101
- model_dir = os .path .join (WORKDIR , model_name )
102
- with open (os .path .join (model_dir , model_name + dyn_str + ".mlir" )) as f :
110
+ model_dir = os .path .join (WORKDIR , model_dir_name )
111
+ with open (
112
+ os .path .join (model_dir , model_name + dyn_str + "_torch.mlir" )
113
+ ) as f :
103
114
mlir_file = f .read ()
104
115
105
116
function_name = str (np .load (os .path .join (model_dir , "function_name.npy" )))
@@ -115,18 +126,21 @@ def gs_download_model():
115
126
def download_tflite_model (model_name , dynamic = False ):
116
127
dyn_str = "_dynamic" if dynamic else ""
117
128
os .makedirs (WORKDIR , exist_ok = True )
118
- if not check_dir_exists (model_name , dyn_str ):
129
+ model_dir_name = model_name + "_tflite"
130
+ if not check_dir_exists (
131
+ model_dir_name , frontend = "tflite" , dynamic = dyn_str
132
+ ):
119
133
gs_command = (
120
134
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
121
135
+ "/"
122
- + model_name
136
+ + model_dir_name
123
137
+ " "
124
138
+ WORKDIR
125
139
)
126
140
if os .system (gs_command ) != 0 :
127
141
raise Exception ("model not present in the tank. Contact Nod Admin" )
128
142
129
- model_dir = os .path .join (WORKDIR , model_name )
143
+ model_dir = os .path .join (WORKDIR , model_dir_name )
130
144
with open (
131
145
os .path .join (model_dir , model_name + dyn_str + "_tflite.mlir" )
132
146
) as f :
0 commit comments