Skip to content

Commit c1cb7bb

Browse files
authored
Rename tank tflite/torch model dir (huggingface#219)
1 parent ff20dde commit c1cb7bb

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

generate_sharktank.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def save_torch_model(torch_model_list):
6161
model, input, _ = get_hf_model(torch_model_name)
6262

6363
torch_model_name = torch_model_name.replace("/", "_")
64-
torch_model_dir = os.path.join(WORKDIR, str(torch_model_name))
64+
torch_model_dir = os.path.join(
65+
WORKDIR, str(torch_model_name) + "_torch"
66+
)
6567
os.makedirs(torch_model_dir, exist_ok=True)
6668

6769
mlir_importer = SharkImporter(
@@ -136,7 +138,7 @@ def save_tflite_model(tflite_model_list):
136138
print("tflite_model_name", tflite_model_name)
137139
print("tflite_model_link", tflite_model_link)
138140
tflite_model_name_dir = os.path.join(
139-
WORKDIR, str(tflite_model_name)
141+
WORKDIR, str(tflite_model_name) + "_tflite"
140142
)
141143
os.makedirs(tflite_model_name_dir, exist_ok=True)
142144
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

shark/shark_downloader.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,26 @@
3535
WORKDIR = os.path.join(home, ".local/shark_tank/")
3636
print(WORKDIR)
3737

38+
3839
# Checks whether the directory and files exists.
3940
def check_dir_exists(model_name, frontend="torch", dynamic=""):
4041
model_dir = os.path.join(WORKDIR, model_name)
4142

4243
# Remove the _tf keyword from end.
4344
if frontend in ["tf", "tensorflow"]:
4445
model_name = model_name[:-3]
46+
elif frontend in ["tflite"]:
47+
model_name = model_name[:-7]
48+
elif frontend in ["torch", "pytorch"]:
49+
model_name = model_name[:-6]
4550

4651
if os.path.isdir(model_dir):
4752
if (
4853
os.path.isfile(
49-
os.path.join(model_dir, model_name + dynamic + ".mlir")
54+
os.path.join(
55+
model_dir,
56+
model_name + dynamic + "_" + str(frontend) + ".mlir",
57+
)
5058
)
5159
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
5260
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
@@ -65,27 +73,28 @@ def download_torch_model(model_name, dynamic=False):
6573
model_name = model_name.replace("/", "_")
6674
dyn_str = "_dynamic" if dynamic else ""
6775
os.makedirs(WORKDIR, exist_ok=True)
76+
model_dir_name = model_name + "_torch"
6877

6978
def gs_download_model():
7079
gs_command = (
7180
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
7281
+ "/"
73-
+ model_name
82+
+ model_dir_name
7483
+ " "
7584
+ WORKDIR
7685
)
7786
if os.system(gs_command) != 0:
7887
raise Exception("model not present in the tank. Contact Nod Admin")
7988

80-
if not check_dir_exists(model_name, dyn_str):
89+
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
8190
gs_download_model()
8291
else:
83-
model_dir = os.path.join(WORKDIR, model_name)
92+
model_dir = os.path.join(WORKDIR, model_dir_name)
8493
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
8594
gs_hash = (
8695
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
8796
+ "/"
88-
+ model_name
97+
+ model_dir_name
8998
+ "/hash.npy"
9099
+ " "
91100
+ os.path.join(model_dir, "upstream_hash.npy")
@@ -98,8 +107,10 @@ def gs_download_model():
98107
if local_hash != upstream_hash:
99108
gs_download_model()
100109

101-
model_dir = os.path.join(WORKDIR, model_name)
102-
with open(os.path.join(model_dir, model_name + dyn_str + ".mlir")) as f:
110+
model_dir = os.path.join(WORKDIR, model_dir_name)
111+
with open(
112+
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir")
113+
) as f:
103114
mlir_file = f.read()
104115

105116
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
@@ -115,18 +126,21 @@ def gs_download_model():
115126
def download_tflite_model(model_name, dynamic=False):
116127
dyn_str = "_dynamic" if dynamic else ""
117128
os.makedirs(WORKDIR, exist_ok=True)
118-
if not check_dir_exists(model_name, dyn_str):
129+
model_dir_name = model_name + "_tflite"
130+
if not check_dir_exists(
131+
model_dir_name, frontend="tflite", dynamic=dyn_str
132+
):
119133
gs_command = (
120134
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
121135
+ "/"
122-
+ model_name
136+
+ model_dir_name
123137
+ " "
124138
+ WORKDIR
125139
)
126140
if os.system(gs_command) != 0:
127141
raise Exception("model not present in the tank. Contact Nod Admin")
128142

129-
model_dir = os.path.join(WORKDIR, model_name)
143+
model_dir = os.path.join(WORKDIR, model_dir_name)
130144
with open(
131145
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
132146
) as f:

0 commit comments

Comments
 (0)