Skip to content

Commit 1599c2b

Browse files
committed
solve comments
1 parent: 6cf7db7 · commit: 1599c2b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

torchchat/cli/builder.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@
3737

3838
# bypass the import issue before torchao is ready on macos
3939
try:
40-
from torchtune.training import set_default_dtype
4140
from torchtune.models.convert_weights import meta_to_tune
4241
except:
4342
pass

torchchat/model.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -281,7 +281,7 @@ def update(self, input_pos, k_val, v_val):
281281
return k_out, v_out
282282

283283

284-
class Model(nn.Module):
284+
class Model(ABC, nn.Module):
285285
"""
286286
The entrance for model construction in torchchat.
287287
"""
@@ -301,10 +301,10 @@ def build_model(self) -> nn.Module:
301301
recipe = ModelRecipe.get_recipe(self.config.model_type)
302302
modules = {}
303303
for name, module_class in recipe.modules.items():
304-
if isinstance(self.config.transformer_args[name], dict):
305-
modules[name] = module_class(**self.config.transformer_args[name])
304+
if isinstance(config_args := self.config.transformer_args[name], dict):
305+
modules[name] = module_class(**config_args)
306306
else:
307-
modules[name] = module_class(self.config.transformer_args[name])
307+
modules[name] = module_class(config_args)
308308

309309
return recipe.fusion_class(**modules)
310310

@@ -369,7 +369,12 @@ def reset_caches(self):
369369

370370

371371
class FlamingoModel(Model):
372-
def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor:
372+
def forward(
373+
self,
374+
tokens: Tensor,
375+
encoder_input: Optional[Dict[str, Tensor]] = None,
376+
encoder_mask: Optional[Tensor] = None,
377+
) -> Tensor:
373378
if encoder_input is None:
374379
return self.model(tokens, encoder_mask=encoder_mask)
375380
return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)

0 commit comments

Comments
 (0)