@@ -281,7 +281,7 @@ def update(self, input_pos, k_val, v_val):
281
281
return k_out , v_out
282
282
283
283
284
- class Model (nn .Module ):
284
+ class Model (ABC , nn .Module ):
285
285
"""
286
286
The entrance for model construction in torchchat.
287
287
"""
@@ -301,10 +301,10 @@ def build_model(self) -> nn.Module:
301
301
recipe = ModelRecipe .get_recipe (self .config .model_type )
302
302
modules = {}
303
303
for name , module_class in recipe .modules .items ():
304
- if isinstance (self .config .transformer_args [name ], dict ):
305
- modules [name ] = module_class (** self . config . transformer_args [ name ] )
304
+ if isinstance (config_args := self .config .transformer_args [name ], dict ):
305
+ modules [name ] = module_class (** config_args )
306
306
else :
307
- modules [name ] = module_class (self . config . transformer_args [ name ] )
307
+ modules [name ] = module_class (config_args )
308
308
309
309
return recipe .fusion_class (** modules )
310
310
@@ -369,7 +369,12 @@ def reset_caches(self):
369
369
370
370
371
371
class FlamingoModel (Model ):
372
- def forward (self , tokens : Tensor , encoder_input : Optional [Dict [str , Tensor ]] = None , encoder_mask : Optional [Tensor ] = None ) -> Tensor :
372
+ def forward (
373
+ self ,
374
+ tokens : Tensor ,
375
+ encoder_input : Optional [Dict [str , Tensor ]] = None ,
376
+ encoder_mask : Optional [Tensor ] = None ,
377
+ ) -> Tensor :
373
378
if encoder_input is None :
374
379
return self .model (tokens , encoder_mask = encoder_mask )
375
380
return self .model (tokens , encoder_input = encoder_input , encoder_mask = encoder_mask )
0 commit comments