Skip to content

Commit 6cf7db7

Browse files
committed
New factory function to produce a Model from ModelArgs
1 parent dc40152 commit 6cf7db7

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

torchchat/generate.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,6 @@ def sample(
303303
temperature: float = 1.0,
304304
top_k: Optional[int] = None,
305305
):
306-
print(f"logits {logits}")
307306
if temperature == 0 and not need_probs:
308307
_, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
309308
return (idx_next, None)
@@ -323,21 +322,15 @@ def prefill(
323322
# logging.debug(f"x: {x}, input_pos: {input_pos}")
324323
width = x.size(1)
325324
assert input_pos.size(0) == width
326-
print("x: ", x, "input_pos: ", input_pos, "width: ", width)
327-
print("sequential_prefill: ", sequential_prefill)
328325

329326
if sequential_prefill:
330327
for i in range(width):
331-
print("i: ", i)
332328
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
333-
print("x_sliced: ", x_sliced, "ip_sliced: ", ip_sliced)
334329
# logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
335330
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
336-
print("logits: ", logits)
337331
else:
338332
# input_pos: [B, S]
339333
logits = model(x, input_pos)
340-
print("logits: ", logits)
341334
# print(f"logits {logits.shape}")
342335

343336
# print(f"x: {x},\n input_pos: {input_pos}\n")

torchchat/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ def build_model(self) -> nn.Module:
310310

311311
@abstractmethod
312312
def forward(self, *args, **kwargs):
313-
pass
313+
raise NotImplementedError("forward method is not implemented")
314314

315315
@abstractmethod
316316
def setup_caches(self, *args, **kwargs):
317-
pass
317+
raise NotImplementedError("setup_caches method is not implemented")
318318

319319
@classmethod
320320
def _get_model_instance(cls, config: ModelArgs):
@@ -323,6 +323,10 @@ def _get_model_instance(cls, config: ModelArgs):
323323
raise ValueError("Unsupported model type:", str(config.model_type))
324324
return model_class(config)
325325

326+
@classmethod
327+
def from_model_args(cls, config: ModelArgs):
328+
return cls._get_model_instance(config)
329+
326330
@classmethod
327331
def from_name(cls, name: str):
328332
return cls._get_model_instance(ModelArgs.from_name(name))

torchchat/utils/gguf_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
558558
# metadata.get(f"{arch}.rope.dimension_count", None)
559559

560560
with torch.device("meta"):
561-
model = Model(model_args)
561+
model = Model.from_model_args(model_args)
562562
return model
563563

564564

0 commit comments

Comments
 (0)