@@ -169,23 +169,15 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
169
169
RESNET_INPUT_SHAPE = [1 , 224 , 224 , 3 ]
170
170
EFFICIENTNET_INPUT_SHAPE = [1 , 384 , 384 , 3 ]
171
171
172
- tf_resnet_model = tf .keras .applications .resnet50 .ResNet50 (
173
- weights = "imagenet" ,
174
- include_top = True ,
175
- input_shape = tuple (RESNET_INPUT_SHAPE [1 :]),
176
- )
177
-
178
- tf_efficientnet_model = tf .keras .applications .efficientnet_v2 .EfficientNetV2S (
179
- weights = "imagenet" ,
180
- include_top = True ,
181
- input_shape = tuple (EFFICIENTNET_INPUT_SHAPE [1 :]),
182
- )
183
-
184
172
185
173
class ResNetModule (tf .Module ):
186
174
def __init__ (self ):
187
175
super (ResNetModule , self ).__init__ ()
188
- self .m = tf_resnet_model
176
+ self .m = tf .keras .applications .resnet50 .ResNet50 (
177
+ weights = "imagenet" ,
178
+ include_top = True ,
179
+ input_shape = tuple (RESNET_INPUT_SHAPE [1 :]),
180
+ )
189
181
self .m .predict = lambda x : self .m .call (x , training = False )
190
182
191
183
@tf .function (
@@ -205,7 +197,11 @@ def preprocess_input(self, image):
205
197
class EfficientNetModule (tf .Module ):
206
198
def __init__ (self ):
207
199
super (EfficientNetModule , self ).__init__ ()
208
- self .m = tf_efficientnet_model
200
+ self .m = tf .keras .applications .efficientnet_v2 .EfficientNetV2S (
201
+ weights = "imagenet" ,
202
+ include_top = True ,
203
+ input_shape = tuple (EFFICIENTNET_INPUT_SHAPE [1 :]),
204
+ )
209
205
self .m .predict = lambda x : self .m .call (x , training = False )
210
206
211
207
@tf .function (
0 commit comments