Skip to content

Commit eb8114e

Browse files
authored
Initialize TF models locally (huggingface#610)
1 parent 616ee9b commit eb8114e

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

tank/model_utils_tf.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,15 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
169169
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
170170
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
171171

172-
tf_resnet_model = tf.keras.applications.resnet50.ResNet50(
173-
weights="imagenet",
174-
include_top=True,
175-
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
176-
)
177-
178-
tf_efficientnet_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
179-
weights="imagenet",
180-
include_top=True,
181-
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
182-
)
183-
184172

185173
class ResNetModule(tf.Module):
186174
def __init__(self):
187175
super(ResNetModule, self).__init__()
188-
self.m = tf_resnet_model
176+
self.m = tf.keras.applications.resnet50.ResNet50(
177+
weights="imagenet",
178+
include_top=True,
179+
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
180+
)
189181
self.m.predict = lambda x: self.m.call(x, training=False)
190182

191183
@tf.function(
@@ -205,7 +197,11 @@ def preprocess_input(self, image):
205197
class EfficientNetModule(tf.Module):
206198
def __init__(self):
207199
super(EfficientNetModule, self).__init__()
208-
self.m = tf_efficientnet_model
200+
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
201+
weights="imagenet",
202+
include_top=True,
203+
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
204+
)
209205
self.m.predict = lambda x: self.m.call(x, training=False)
210206

211207
@tf.function(

0 commit comments

Comments
 (0)