diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 712b91d93..4f3f5727c 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -35,6 +35,13 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+# bypass the import issue before torchao is ready on macos
+try:
+    from torchtune.models.convert_weights import meta_to_tune
+except:
+    pass
+
+
 @dataclass
 class BuilderArgs:
@@ -328,11 +335,15 @@ def _load_model_default(builder_args, only_config=False):
     assert not builder_args.gguf_path
 
     model = _init_model_on_meta_device(builder_args)
-    # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
-    cps = []
-    if builder_args.checkpoint_dir is not None:
+
+    if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        print("Loading Tune checkpoint")
+        meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
+        checkpoint = meta_to_tune(meta_checkpoint)
+    elif builder_args.checkpoint_dir is not None:
         # Load multiple checkpoint; ignore the single path.
         builder_args.checkpoint_path = None
+        cps = []
         for i in range(4):
             cp_name = f"consolidated.{i}.pth"
             print(f"Loading {cp_name}")
@@ -363,10 +374,10 @@ def _load_model_default(builder_args, only_config=False):
 
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
-
-    checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}
+    checkpoint = {"model." + k: v for k, v in checkpoint.items()}
 
     model.load_state_dict(checkpoint, assign=True, strict=True)
+
     return model
@@ -534,7 +545,9 @@ def _initialize_model(
     if builder_args.setup_caches:
         with torch.device(builder_args.device):
             model.setup_caches(
-                max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length
+                max_batch_size=1,
+                max_seq_length=max_seq_length
+                or model.config.transformer_args["text"].max_seq_length,
             )
 
     model.to(dtype=builder_args.precision)
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 1d9114d67..67e9b9ae8 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -116,6 +116,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
    max_autotune: bool = False
+    is_torchtune_model: bool = False
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -161,6 +162,7 @@ def from_args(cls, args):
             speculate_k=args.speculate_k,
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
+            is_torchtune_model=args.model and args.model.endswith("tune"),
         )
@@ -197,6 +199,8 @@ def __init__(
         self.profile = profile
         self.quantize = quantize
         self.draft_quantize = draft_quantize
+        self.is_torchtune_model = generator_args.is_torchtune_model
+        self.dtype = builder_args.precision
 
         # global print
         # from tp import maybe_init_dist
@@ -263,7 +267,10 @@ def __init__(
         else:
             self.draft_model = None
 
-        self.tokenizer_args.validate_model(self.model)
+        # torchtune model does not contain essential info for validation
+        # TODO: refactor model config to be more generic
+        if not self.is_torchtune_model:
+            self.tokenizer_args.validate_model(self.model)
         self.tokenizer_args.validate_model(self.draft_model, "draft model")
         generator_args.validate_build(self.builder_args)
         generator_args.validate_build(self.speculative_builder_args, "draft model")
@@ -295,7 +302,7 @@ def sample(
         need_probs: bool,
         temperature: float = 1.0,
         top_k: Optional[int] = None,
-    ):
+    ):
         if temperature == 0 and not need_probs:
             _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
             return (idx_next, None)
@@ -517,7 +524,10 @@ def generate(
         if start_pos == 0:
             model = model.to(device=device)
             with torch.device(device):
-                model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                if self.is_torchtune_model:
+                    model.setup_caches(max_batch_size=1, dtype=self.dtype)
+                else:
+                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
             if is_speculative and draft_model is not model:
                 draft_model.setup_caches(
                     max_batch_size=1, max_seq_length=max_seq_length
                 )
@@ -686,7 +696,12 @@ def chat(
                 self.system_prompt = None
 
         # Set up our max_seq_length
-        if generator_args.chat_mode:
+
+        # This is a hack: different models record their max_seq_length differently, and the recorded value may be wrong.
+        # TODO: unify the max_seq_length config representation.
+        if generator_args.is_torchtune_model:
+            max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"]
+        elif generator_args.chat_mode:
             max_seq_length = self.model.config.transformer_args["text"].max_seq_length
             print(
                 f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
diff --git a/torchchat/model.py b/torchchat/model.py
index 500f2c71c..f0910f54a 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -11,6 +11,7 @@
 from enum import Enum
 from pathlib import Path
 from typing import Callable, Dict, Optional, Union
+from abc import ABC, abstractmethod
 
 import torch
 import torch.nn as nn
@@ -33,13 +34,20 @@
 try:
     from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
     from torchtune.modules.model_fusion import DeepFusionModel
+    from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
 except:
     pass
 
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
 
+def identity(**kwargs):
+    if len(kwargs) != 1:
+        raise ValueError("Only one argument is expected")
+    return list(kwargs.values())[0]
+
 class ModelType(Enum):
     TextOnly = "text_only"
+    Llama3_1 = "llama3_1"
     Flamingo = "flamingo"
 
 # Type for objects that can generate nn.Module instance
@@ -72,9 +80,18 @@ class ModelRecipe:
     def _text_only(cls):
         return cls(
             model_type=ModelType.TextOnly,
-            modules={'text_transformer': Transformer},
-            fusion_class=nn.Identity,
+            modules={'text': Transformer},
+            fusion_class=identity,
+        )
+
+    @classmethod
+    def _llama3_1(cls):
+        return cls(
+            model_type=ModelType.Llama3_1,
+            modules={'text': llama3_1_builder},
+            fusion_class=identity,
         )
+
     @classmethod
     def _flamingo(cls):
         return cls(
@@ -92,6 +109,8 @@ def get_recipe(cls, model_type):
             return cls._text_only()
         elif model_type == ModelType.Flamingo:
             return cls._flamingo()
+        elif model_type == ModelType.Llama3_1:
+            return cls._llama3_1()
         else:
             raise ValueError(f"Can not find the model recipe for {model_type}")
@@ -184,11 +203,7 @@ def from_params(cls, params_path):
         except TypeError:
             # try to interpret as a dict of transformer configs
             model_type = ModelType(loaded_params["model_type"])
-
-            # Currently only supporting flamingo model
-            assert model_type == ModelType.Flamingo
             transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}
-
             return cls(transformer_args, model_type)
 
     @classmethod
@@ -266,18 +281,14 @@ def update(self, input_pos, k_val, v_val):
 
         return k_out, v_out
 
-class Model(nn.Module):
+class Model(ABC, nn.Module):
     """
     The entrance for model construction in torchchat.
     """
 
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
-        # TODO: unify the model init logic
-        if config.model_type == ModelType.TextOnly:
-            self.text_transformer = Transformer(config.transformer_args["text"])
-        else:
-            self.model = self.build_model()
+        self.model = self.build_model()
 
     def build_model(self) -> nn.Module:
         """
@@ -290,50 +301,43 @@ def build_model(self) -> nn.Module:
         recipe = ModelRecipe.get_recipe(self.config.model_type)
         modules = {}
         for name, module_class in recipe.modules.items():
-            modules[name] = module_class(**self.config.transformer_args[name])
-
+            if isinstance(config_args := self.config.transformer_args[name], dict):
+                modules[name] = module_class(**config_args)
+            else:
+                modules[name] = module_class(config_args)
+
         return recipe.fusion_class(**modules)
+
+    @abstractmethod
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError("forward method is not implemented")
 
-    def forward(self,
-                tokens: Optional[Tensor] = None,
-                input_pos: Optional[Tensor] = None,
-                encoder_input: Optional[Dict[str, Tensor]] = None,
-                encoder_mask: Optional[Tensor] = None) -> Tensor:
+    @abstractmethod
+    def setup_caches(self, *args, **kwargs):
+        raise NotImplementedError("setup_caches method is not implemented")
 
-        if self.config.model_type == ModelType.TextOnly:
-            return self.text_transformer(tokens, input_pos)
-        else:
-            assert self.config.model_type == ModelType.Flamingo
-            if input_pos:
-                warnings.warn("input_pos is not used for Flamingo model. Ignoring it.")
-            if encoder_input is None:
-                return self.model(tokens, encoder_mask = encoder_mask)
-            return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask)
-
-    def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None):
-        if self.config.model_type == ModelType.TextOnly:
-            self.text_transformer.setup_caches(max_batch_size, max_seq_length)
-        else:
-            assert self.config.model_type == ModelType.Flamingo
-            if max_seq_length is not None:
-                warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.")
Ignoring it.") - self.model.setup_caches(max_batch_size, dtype=dtype) - - def reset_caches(self): - assert self.config.model_type == ModelType.Flamingo - self.model.reset_caches() + @classmethod + def _get_model_instance(cls, config: ModelArgs): + model_class = MODEL_TYPE_TO_CLASS.get(config.model_type) + if model_class is None: + raise ValueError("Unsupported model type:", str(config.model_type)) + return model_class(config) + + @classmethod + def from_model_args(cls, config: ModelArgs): + return cls._get_model_instance(config) @classmethod def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) + return cls._get_model_instance(ModelArgs.from_name(name)) @classmethod def from_table(cls, name: str): - return cls(ModelArgs.from_table(name)) + return cls._get_model_instance(ModelArgs.from_table(name)) @classmethod def from_params(cls, params_path: str): - return cls(ModelArgs.from_params(params_path)) + return cls._get_model_instance(ModelArgs.from_params(params_path)) @classmethod def from_gguf(cls, gguf_path: str, **kwargs): @@ -345,6 +349,49 @@ def from_gguf(cls, gguf_path: str, **kwargs): return model +class TextOnlyModel(Model): + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + return self.model(tokens, input_pos) + + def setup_caches(self, max_batch_size, max_seq_length): + self.model.setup_caches(max_batch_size, max_seq_length) + + +class Llama31Model(Model): + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + return self.model(tokens=tokens, input_pos=input_pos) + + def setup_caches(self, max_batch_size, dtype): + self.model.setup_caches(max_batch_size, dtype=dtype) + + def reset_caches(self): + self.model.reset_caches() + + +class FlamingoModel(Model): + def forward( + self, + tokens: Tensor, + encoder_input: Optional[Dict[str, Tensor]] = None, + encoder_mask: Optional[Tensor] = None, + ) -> Tensor: + if encoder_input is None: + return self.model(tokens, encoder_mask=encoder_mask) + return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask) + + def setup_caches(self, max_batch_size, dtype): + self.model.setup_caches(max_batch_size, dtype=dtype) + + def reset_caches(self): + self.model.reset_caches() + + +MODEL_TYPE_TO_CLASS = { + ModelType.TextOnly: TextOnlyModel, + ModelType.Flamingo: FlamingoModel, + ModelType.Llama3_1: Llama31Model, +} + class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: super().__init__() diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json index ab0abb7d6..ca8c5acdf 100644 --- a/torchchat/model_config/models.json +++ b/torchchat/model_config/models.json @@ -57,6 +57,18 @@ "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct", "transformer_params_key": "Meta-Llama-3.1-70B" }, + "meta-llama/Meta-Llama-3.1-8B-Instruct-Tune": { + "aliases": ["llama3.1-tune", "llama3.1-chat-tune", "llama3.1-instruct-tune"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "transformer_params_key": "Meta-Llama-3.1-8B-Tune" + }, + "meta-llama/Meta-Llama-3.1-70B-Instruct-Tune": { + "aliases": ["llama3.1-70b-tune"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct", + "transformer_params_key": "Meta-Llama-3.1-70B-Tune" + }, "meta-llama/CodeLlama-7b-Python-hf": { "aliases": ["codellama", "codellama-7b"], "distribution_channel": "HuggingFaceSnapshot", diff --git 
diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
new file mode 100644
index 000000000..c59961c63
--- /dev/null
+++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
@@ -0,0 +1,15 @@
+{
+  "model_type": "llama3_1",
+  "text": {
+    "vocab_size": 128256,
+    "num_layers": 80,
+    "num_heads": 64,
+    "num_kv_heads": 8,
+    "embed_dim": 8192,
+    "max_seq_len": 8192,
+    "intermediate_dim": 28672,
+    "attn_dropout": 0.0,
+    "norm_eps": 1e-5,
+    "rope_base": 500000.0
+  }
+}
diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
new file mode 100644
index 000000000..e9ded77bd
--- /dev/null
+++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
@@ -0,0 +1,15 @@
+{
+  "model_type": "llama3_1",
+  "text": {
+    "vocab_size": 128256,
+    "num_layers": 32,
+    "num_heads": 32,
+    "num_kv_heads": 8,
+    "embed_dim": 4096,
+    "max_seq_len": 8192,
+    "intermediate_dim": 14336,
+    "attn_dropout": 0.0,
+    "norm_eps": 1e-5,
+    "rope_base": 500000.0
+  }
+}
diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py
index 6e87a8a9f..c7b931dae 100644
--- a/torchchat/utils/gguf_loader.py
+++ b/torchchat/utils/gguf_loader.py
@@ -47,7 +47,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
     result = copy.deepcopy(gguf_name)
     for gguf_string, replacement in _name_replacements:
         result = result.replace(gguf_string, replacement)
-    result = "text_transformer." + result
+    result = "model." + result
     return result
 
@@ -558,7 +558,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     #  metadata.get(f"{arch}.rope.dimension_count", None)
 
     with torch.device("meta"):
-        model = Model(model_args)
+        model = Model.from_model_args(model_args)
     return model