From 0dae9eff0de35381b29235b2683944d704748cf3 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 16:37:47 -0700 Subject: [PATCH 01/43] added model source and type for torchtune flamingo support --- .DS_Store | Bin 0 -> 6148 bytes build/.DS_Store | Bin 0 -> 6148 bytes build/model.py | 21 ++++++++++++++++++--- flamingo.json | 28 ++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 .DS_Store create mode 100644 build/.DS_Store create mode 100644 flamingo.json diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..16d875f048b003f89438f610b00ecfc05d69253a GIT binary patch literal 6148 zcmeH~J&pn~427Thk&tL8DbsL(y+MTF1YBUnPJ=WO#fm;h=h<(?J6KYLuU74ZjM<6h0 J5P^Rs@C1L66M6su literal 0 HcmV?d00001 diff --git a/build/.DS_Store b/build/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..658fefb065d303740d88fb70cd071aca7d48bab7 GIT binary patch literal 6148 zcmeHKyH3ME5S)b+k!T_+%KHWWz>2~b@BzF`g@X|hqIJc0@oCIH8q3f`qKRguy}8@l zxzkPI^#ZW%M}G}$0Icbb`0`kSWhr|)s-4^PA6I1WybNdYM!1*Cu!kOD_5P^ENzar6lyN&zWwstWk`q0t??!YMI6 z9UP(sAg&k=<2-r^V)Fp8E1VJ;p;=OiNwsP*Ea{B5%IgZJ#H7Qj`LMd#szb53o#(em zhjob>rGONeD{!96jo1Gh`XBxOoTQZ$kOC*AfUS17yDgtowRQ10ueFVSPxqW}x*O*~ o;SlAR80DA?FUMDrlzGkP-0upf#Go@CbfSI+To;)X_-_S%08E`2 None: super().__init__() self.config = config - self.text_transformer = Transformer(config.text_transformer_args) + if config.source == "native": + assert config.model_type == ModelType.TextOnly, "only text-only model is supported natively. For Flamingo, use torchtune" + self.text_transformer = Transformer(config.transformer_args["text"]) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.text_transformer(idx, input_pos) diff --git a/flamingo.json b/flamingo.json new file mode 100644 index 000000000..d5ecb1dd8 --- /dev/null +++ b/flamingo.json @@ -0,0 +1,28 @@ +{ + "source": "torchtune", + "model_name": "flamingo", + "encoder": { + "patch_size": 14, + "num_heads": 16, + "clip_embed_dim": 1280, + "clip_num_layers": 32, + "clip_hidden_states": [3, 7, 15, 23, 30], + "decoder_embed_dim": 4096, + "num_layers_projection": 8, + "tile_size": 448, + "max_num_tiles": 4, + "in_channels": 3 + }, + "decoder": { + "vocab_size": 128256, + "num_layers": 32, + "fusion_interval": 4, + "num_special_tokens": 8, + "num_heads": 32, + "num_kv_heads": 8, + "embed_dim": 4096, + "max_seq_len": 8192, + "rope_base": 500000.0, + "intermediate_dim": 14336 + } +} From 87397e32f5de19366b03aef51e939b26a8a225ef Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 16:41:50 -0700 Subject: [PATCH 02/43] added model source and type for torchtune flamingo support --- flamingo.json | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 flamingo.json diff --git a/flamingo.json b/flamingo.json deleted file mode 100644 index d5ecb1dd8..000000000 --- a/flamingo.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "source": "torchtune", - "model_name": "flamingo", - "encoder": { - "patch_size": 14, - "num_heads": 16, - "clip_embed_dim": 1280, - "clip_num_layers": 32, - "clip_hidden_states": [3, 7, 15, 23, 30], - "decoder_embed_dim": 4096, - "num_layers_projection": 8, - "tile_size": 448, - "max_num_tiles": 4, - "in_channels": 3 - }, - "decoder": { - "vocab_size": 128256, - "num_layers": 32, - "fusion_interval": 4, - "num_special_tokens": 8, - "num_heads": 32, - "num_kv_heads": 8, - "embed_dim": 4096, - "max_seq_len": 8192, - "rope_base": 500000.0, - 
"intermediate_dim": 14336 - } -} From 0f6161418d9324a1abb4aae987b89727d3ef15df Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 16:48:16 -0700 Subject: [PATCH 03/43] grab missing enum --- build/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/build/model.py b/build/model.py index 48347d5b8..6a788d611 100644 --- a/build/model.py +++ b/build/model.py @@ -17,6 +17,7 @@ from torch.nn import functional as F from build.utils import find_multiple, get_precision +from enum import Enum config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") From d7f3a88cc67c4c54966801e3ac08702b37d6d2ac Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 16:54:32 -0700 Subject: [PATCH 04/43] fix ModelArgs init --- build/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/build/model.py b/build/model.py index 6a788d611..67858cad3 100644 --- a/build/model.py +++ b/build/model.py @@ -29,7 +29,7 @@ class ModelType(Enum): class ModelSource(Enum): Native = "native" - Gguf = "gguf" + Torchtune = "torchtune" @dataclass @@ -79,9 +79,9 @@ def from_params(cls, params): @dataclass class ModelArgs: - model_source: ModelSource = ModelSource.Native - model_type: ModelType = ModelType.TextOnly - transformer_args: Dict[str, TransformerArgs] = None + model_source: ModelSource + model_type: ModelType + transformer_args: Dict[str, TransformerArgs] def __post_init__(self): assert self.text_transformer_args is not None @@ -94,19 +94,21 @@ def from_params(cls, params_path): try: # try to interpret as a single transformer config + transformer_args: Dict[str, TransformerArgs] = {} transformer_args['text'] = TransformerArgs.from_params( loaded_params ) except TypeError: # try to interpret as a dict of transformer configs # now only support flamingo model + assert False, "flamingo model is not supported yet" for name, params in loaded_params.items(): if name == "text": text_transformer_args = TransformerArgs.from_params(params) else: raise ValueError(f"Unknown transformer name {name}") - return cls(text_transformer_args) + return cls(model_source, model_type, transformer_args) @classmethod def from_table(cls, name: str): From 994b1480e7cd7215566c16c868d88c8cd7633c74 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 18:37:07 -0700 Subject: [PATCH 05/43] create init func for ModelArgs for BC --- build/known_model_params/.DS_Store | Bin 0 -> 6148 bytes build/model.py | 34 ++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 10 deletions(-) create mode 100644 build/known_model_params/.DS_Store diff --git a/build/known_model_params/.DS_Store b/build/known_model_params/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 None: + self.model_source = model_source + self.model_type = model_type + if isinstance(transformer_args, TransformerArgs): + self.transformer_args = {"text": transformer_args} + else: + self.transformer_args = transformer_args + def __post_init__(self): assert self.text_transformer_args is not None assert type(self.text_transformer_args) == TransformerArgs @@ -95,9 +109,7 @@ def from_params(cls, params_path): try: # try to interpret as a single transformer config transformer_args: Dict[str, TransformerArgs] = {} - transformer_args['text'] = 
TransformerArgs.from_params( - loaded_params - ) + transformer_args["text"] = TransformerArgs.from_params(loaded_params) except TypeError: # try to interpret as a dict of transformer configs # now only support flamingo model @@ -190,12 +202,14 @@ def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config if config.source == "native": - assert config.model_type == ModelType.TextOnly, "only text-only model is supported natively. For Flamingo, use torchtune" + assert ( + config.model_type == ModelType.TextOnly + ), "only text-only model is supported natively. For Flamingo, use torchtune" self.text_transformer = Transformer(config.transformer_args["text"]) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.text_transformer(idx, input_pos) - + def setup_caches(self, max_batch_size, max_seq_length): self.text_transformer.setup_caches(max_batch_size, max_seq_length) @@ -481,11 +495,10 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: # ExecuTorch model components # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -try: - from executorch.extension.pybindings import portable_lib as exec_lib - +try: # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately. - from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa + from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa + from executorch.extension.pybindings import portable_lib as exec_lib class PTEModel(nn.Module): def __init__(self, config, path) -> None: @@ -506,5 +519,6 @@ def forward(self, x, input_pos): def setup_caches(self, max_batch_size, max_seq_length): pass + except: pass From d184e68782ea866da2798f375d781dd2181e5c5e Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 19:25:34 -0700 Subject: [PATCH 06/43] update pipeline for ModleSource and ModelType --- build/builder.py | 4 ++-- build/model.py | 31 ++++++++++++++++++++++++++++--- distributed/parallelize_llama.py | 2 +- eval.py | 2 +- export.py | 2 +- generate.py | 5 ++--- torchchat/usages/openai_api.py | 4 ++-- 7 files changed, 37 insertions(+), 13 deletions(-) diff --git a/build/builder.py b/build/builder.py index 635cca152..79bd958bf 100644 --- a/build/builder.py +++ b/build/builder.py @@ -225,7 +225,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.config.text_transformer_args.use_tiktoken + use_tiktoken = model.text_transformer.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -529,7 +529,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.setup_caches( - max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length + max_batch_size=1, max_seq_length=max_seq_length or model.text_transformer.config.max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/build/model.py b/build/model.py index e64d8f0f6..2a9b9d6df 100644 --- a/build/model.py +++ b/build/model.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch import torch.nn as nn @@ -90,6 +90,8 @@ def __init__( model_source: ModelSource = ModelSource.Native, model_type: ModelType = ModelType.TextOnly, ) -> None: + 
self._sanity_check(transformer_args, model_source, model_type) + self.model_source = model_source self.model_type = model_type if isinstance(transformer_args, TransformerArgs): @@ -100,6 +102,23 @@ def __init__( def __post_init__(self): assert self.text_transformer_args is not None assert type(self.text_transformer_args) == TransformerArgs + + def _sanity_check( + self, + transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + model_source: ModelSource, + model_type: ModelType, + ) -> None: + + assert isinstance(model_source, ModelSource) + assert isinstance(model_type, ModelType) + assert isinstance(transformer_args, (TransformerArgs, dict)) + + assert model_source in [ModelSource.Native], "only native model is supported" + assert ( + model_type == ModelType.TextOnly + ), "only text-only model is supported natively. For Flamingo, use torchtune" + @classmethod def from_params(cls, params_path): @@ -110,17 +129,21 @@ def from_params(cls, params_path): # try to interpret as a single transformer config transformer_args: Dict[str, TransformerArgs] = {} transformer_args["text"] = TransformerArgs.from_params(loaded_params) + model_source = ModelSource.Native + model_type = ModelType.TextOnly except TypeError: # try to interpret as a dict of transformer configs # now only support flamingo model assert False, "flamingo model is not supported yet" + model_source = loaded_params["model_source"] + model_type = loaded_params["model_type"] for name, params in loaded_params.items(): if name == "text": text_transformer_args = TransformerArgs.from_params(params) else: raise ValueError(f"Unknown transformer name {name}") - return cls(model_source, model_type, transformer_args) + return cls(transformer_args, model_source, model_type) @classmethod def from_table(cls, name: str): @@ -201,11 +224,13 @@ class Model(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - if config.source == "native": + if config.model_source == ModelSource.Native: assert ( config.model_type == ModelType.TextOnly ), "only text-only model is supported natively. For Flamingo, use torchtune" self.text_transformer = Transformer(config.transformer_args["text"]) + else: + assert False, "only native model is supported" def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.text_transformer(idx, input_pos) diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index f0d12d769..716cfe37d 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -59,7 +59,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. 
- model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size() + model.text_transformer.config.n_local_heads = model.text_transformer.config.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/eval.py b/eval.py index b79757ea0..6032d7ac0 100644 --- a/eval.py +++ b/eval.py @@ -58,7 +58,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.config.text_transformer_args.block_size) + max_seq_length = min(T_new, model.text_transformer.config.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/export.py b/export.py index 2b85fbb11..f52d71df4 100644 --- a/export.py +++ b/export.py @@ -56,7 +56,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.text_transformer_args.max_seq_length) + seq = Dim("seq", min=1, max=model.text_transformer.config.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} else: diff --git a/generate.py b/generate.py index ee69e574f..f09b10382 100644 --- a/generate.py +++ b/generate.py @@ -676,7 +676,7 @@ def chat( self.system_prompt = None # Set up our max_seq_length if generator_args.chat_mode: - max_seq_length = self.model.config.text_transformer_args.max_seq_length + max_seq_length = self.model.text_transformer.config.max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -689,7 +689,7 @@ def chat( else: max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.config.text_transformer_args.block_size, + self.model.text_transformer.config.block_size, ) max_seq_length = ( @@ -903,4 +903,3 @@ def main(args): check_args(args, verb) args = arg_init(args) main(args) - diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index ab3e15e0b..005eb7ab9 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -232,11 +232,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_length = ( - self.model.config.text_transformer_args.max_seq_length + self.model.text_transformer.config.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.config.text_transformer_args.max_seq_length + else self.model.text_transformer.config.max_seq_length ) # The System fingerprint is a unique identifier for the model and its configuration. 
self.system_fingerprint = ( From 0d8e36826fbaf8b5feae3d3cf6d203204c2075fd Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Wed, 28 Aug 2024 10:56:00 -0700 Subject: [PATCH 07/43] revert lintrunner update on ET --- build/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build/model.py b/build/model.py index 2a9b9d6df..10821155e 100644 --- a/build/model.py +++ b/build/model.py @@ -521,9 +521,10 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ try: + from executorch.extension.pybindings import portable_lib as exec_lib + # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately. from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa - from executorch.extension.pybindings import portable_lib as exec_lib class PTEModel(nn.Module): def __init__(self, config, path) -> None: From 6c78850d00f7895ebf864e587ce2cce9418a2d2a Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 28 Aug 2024 12:42:20 -0700 Subject: [PATCH 08/43] introduce flamingo modules form torchtune --- build/builder.py | 4 ++-- build/model.py | 38 ++++++++++++++++++++------------ distributed/parallelize_llama.py | 2 +- eval.py | 2 +- export.py | 2 +- generate.py | 4 ++-- install_requirements.sh | 6 +++++ torchchat/usages/openai_api.py | 4 ++-- 8 files changed, 39 insertions(+), 23 deletions(-) diff --git a/build/builder.py b/build/builder.py index 79bd958bf..a72e4e13a 100644 --- a/build/builder.py +++ b/build/builder.py @@ -225,7 +225,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.text_transformer.config.use_tiktoken + use_tiktoken = model.config.transformer_args["text"].use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -529,7 +529,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.setup_caches( - max_batch_size=1, max_seq_length=max_seq_length or model.text_transformer.config.max_seq_length + max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/build/model.py b/build/model.py index 10821155e..df1c45bf6 100644 --- a/build/model.py +++ b/build/model.py @@ -19,6 +19,9 @@ from build.utils import find_multiple, get_precision +from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder +from torchtune.modules.model_fusion import DeepFusionModel + config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") @@ -31,7 +34,6 @@ class ModelSource(Enum): Native = "native" Torchtune = "torchtune" - @dataclass class TransformerArgs: block_size: int = 2048 @@ -82,7 +84,7 @@ def from_params(cls, params): class ModelArgs: model_source: ModelSource model_type: ModelType - transformer_args: Dict[str, TransformerArgs] + transformer_args: Dict[str, Union[Dict, TransformerArgs]] def __init__( self, @@ -98,10 +100,6 @@ def __init__( self.transformer_args = {"text": transformer_args} else: self.transformer_args = transformer_args - - def __post_init__(self): - assert self.text_transformer_args is not None - assert type(self.text_transformer_args) == TransformerArgs def _sanity_check( self, @@ -133,15 +131,12 @@ def from_params(cls, params_path): model_type = ModelType.TextOnly except TypeError: # try to interpret as a dict of 
transformer configs - # now only support flamingo model - assert False, "flamingo model is not supported yet" model_source = loaded_params["model_source"] model_type = loaded_params["model_type"] - for name, params in loaded_params.items(): - if name == "text": - text_transformer_args = TransformerArgs.from_params(params) - else: - raise ValueError(f"Unknown transformer name {name}") + + # now only support flamingo model + assert model_source == ModelSource.Torchtune and model_type == ModelType.Flamingo + transformer_args = {k: v for k, v in loaded_params.items() if k != "model_source" and k != "model_type"} return cls(transformer_args, model_source, model_type) @@ -230,7 +225,22 @@ def __init__(self, config: ModelArgs) -> None: ), "only text-only model is supported natively. For Flamingo, use torchtune" self.text_transformer = Transformer(config.transformer_args["text"]) else: - assert False, "only native model is supported" + assert config.model_source == ModelSource.Torchtune + assert config.model_type == ModelType.Flamingo, "currently only Flamingo model is supported from torchtune" + + assert "encoder" in config.transformer_args and "decoder" in config.transformer_args, "config is missing essential transformer args for Flamingo model" + encoder = flamingo_vision_encoder( + *config.transformer_args["encoder"] + ) + decoder = flamingo_decoder( + *config.transformer_args["decoder"] + ) + self.model = DeepFusionModel( + encoder=encoder, + decoder=decoder, + ) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.text_transformer(idx, input_pos) diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index 716cfe37d..24dad2679 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -59,7 +59,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. 
- model.text_transformer.config.n_local_heads = model.text_transformer.config.n_local_heads // tp_mesh.size() + model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/eval.py b/eval.py index 6032d7ac0..066d5d085 100644 --- a/eval.py +++ b/eval.py @@ -58,7 +58,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.text_transformer.config.block_size) + max_seq_length = min(T_new, model.config.transformer_args["text"].block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/export.py b/export.py index f52d71df4..afa4168b0 100644 --- a/export.py +++ b/export.py @@ -56,7 +56,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.text_transformer.config.max_seq_length) + seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} else: diff --git a/generate.py b/generate.py index f09b10382..bbb995344 100644 --- a/generate.py +++ b/generate.py @@ -676,7 +676,7 @@ def chat( self.system_prompt = None # Set up our max_seq_length if generator_args.chat_mode: - max_seq_length = self.model.text_transformer.config.max_seq_length + max_seq_length = self.model.config.transformer_args["text"].max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -689,7 +689,7 @@ def chat( else: max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.text_transformer.config.block_size, + self.model.config.transformer_args["text"].block_size, ) max_seq_length = ( diff --git a/install_requirements.sh b/install_requirements.sh index 9baac5ec0..87cd1085e 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -78,6 +78,12 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) +# Install torchtune separately with the --pre flag +( + set -x + $PIP_EXECUTABLE install --pre torchtune --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir +) + # For torchao need to install from github since nightly build doesn't have macos build. # TODO: Remove this and install nightly build, once it supports macos ( diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 005eb7ab9..319ce7939 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -232,11 +232,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_length = ( - self.model.text_transformer.config.max_seq_length + self.model.config.transformer_args["text"].max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.text_transformer.config.max_seq_length + else self.model.config.transformer_args["text"].max_seq_length ) # The System fingerprint is a unique identifier for the model and its configuration. 
self.system_fingerprint = ( From 2691bae1af362337d3f2d29baafdffd52501b980 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 28 Aug 2024 14:43:44 -0700 Subject: [PATCH 09/43] back up to move to linux --- install_requirements.sh | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/install_requirements.sh b/install_requirements.sh index 87cd1085e..b30af0e6e 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -78,18 +78,19 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune separately with the --pre flag -( - set -x - $PIP_EXECUTABLE install --pre torchtune --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir -) - -# For torchao need to install from github since nightly build doesn't have macos build. -# TODO: Remove this and install nightly build, once it supports macos +# Install torchtune from Philip forked repository due to flamingo components have not been landed yet. +# TODO: Use torchtune official repository instead, when flamingo components have been merged. ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 + $PIP_EXECUTABLE install git+https://github.com/pbontrager/torchtune.git@flamingo_components ) + +# # For torchao need to install from github since nightly build doesn't have macos build. +# # TODO: Remove this and install nightly build, once it supports macos +# ( +# set -x +# $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 +# ) if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x From ba960f0a910a999eca7ea2baf0ff29223adc8ae1 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 28 Aug 2024 17:10:34 -0700 Subject: [PATCH 10/43] mitigate building issue --- build/builder.py | 2 +- build/convert_hf_checkpoint.py | 2 +- dist_run.py | 2 +- install_requirements.sh | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/build/builder.py b/build/builder.py index a72e4e13a..e02a9b8d9 100644 --- a/build/builder.py +++ b/build/builder.py @@ -23,7 +23,7 @@ ParallelDims, parallelize_llama, ) -from quantization.quantize import quantize_model +# from quantization.quantize import quantize_model from torch.distributed.device_mesh import DeviceMesh from utils.measure_time import measure_time diff --git a/build/convert_hf_checkpoint.py b/build/convert_hf_checkpoint.py index de176af56..ab8e23f00 100644 --- a/build/convert_hf_checkpoint.py +++ b/build/convert_hf_checkpoint.py @@ -32,7 +32,7 @@ def convert_hf_checkpoint( if model_name is None: model_name = model_dir.name - config = ModelArgs.from_name(model_name).text_transformer_args + config = ModelArgs.from_name(model_name).transformer_args['text'] print(f"Model config {config.__dict__}") # Load the json file containing weight mapping diff --git a/dist_run.py b/dist_run.py index 34732a008..533273170 100644 --- a/dist_run.py +++ b/dist_run.py @@ -16,7 +16,7 @@ # Model config def main(): - config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args + config = ModelArgs.from_name("Transformer-2-7b-chat-hf").transformer_args['text'] print(config) # Construct a device mesh with available devices (multi-host or single host) diff --git a/install_requirements.sh b/install_requirements.sh index b30af0e6e..884ef53be 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -67,7 +67,8 @@ fi # pip packages needed by exir. 
REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${NIGHTLY_VERSION}" + torch=="2.4.0" + torchvision ) # Install the requirements. `--extra-index-url` tells pip to look for package @@ -89,7 +90,7 @@ REQUIREMENTS_TO_INSTALL=( # # TODO: Remove this and install nightly build, once it supports macos # ( # set -x -# $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 +# $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@cfabc13e72fd03934e62a2a03903bc1678235bed # ) if [[ -x "$(command -v nvidia-smi)" ]]; then ( From 8b3a684723459f5d5e2847c71293fc9463a3e38b Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 30 Aug 2024 15:43:02 -0700 Subject: [PATCH 11/43] pass local test --- build/model.py | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/build/model.py b/build/model.py index df1c45bf6..5aa31dc17 100644 --- a/build/model.py +++ b/build/model.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import json import os +import warnings from dataclasses import dataclass from enum import Enum @@ -112,12 +113,6 @@ def _sanity_check( assert isinstance(model_type, ModelType) assert isinstance(transformer_args, (TransformerArgs, dict)) - assert model_source in [ModelSource.Native], "only native model is supported" - assert ( - model_type == ModelType.TextOnly - ), "only text-only model is supported natively. For Flamingo, use torchtune" - - @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: @@ -131,8 +126,8 @@ def from_params(cls, params_path): model_type = ModelType.TextOnly except TypeError: # try to interpret as a dict of transformer configs - model_source = loaded_params["model_source"] - model_type = loaded_params["model_type"] + model_source = ModelSource(loaded_params["model_source"]) + model_type = ModelType(loaded_params["model_type"]) # now only support flamingo model assert model_source == ModelSource.Torchtune and model_type == ModelType.Flamingo @@ -230,23 +225,37 @@ def __init__(self, config: ModelArgs) -> None: assert "encoder" in config.transformer_args and "decoder" in config.transformer_args, "config is missing essential transformer args for Flamingo model" encoder = flamingo_vision_encoder( - *config.transformer_args["encoder"] + **config.transformer_args["encoder"] ) decoder = flamingo_decoder( - *config.transformer_args["decoder"] + **config.transformer_args["decoder"] ) self.model = DeepFusionModel( encoder=encoder, decoder=decoder, ) - - - def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: - return self.text_transformer(idx, input_pos) - - def setup_caches(self, max_batch_size, max_seq_length): - self.text_transformer.setup_caches(max_batch_size, max_seq_length) + def forward(self, idx: Tensor, imgs: Optional[Tensor] = None, aspect_ratio: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) -> Tensor: + if self.config.model_type == ModelType.TextOnly: + return self.text_transformer(idx, input_pos) + else: + assert self.config.model_type == ModelType.Flamingo + if imgs is None: + return self.model(idx, input_pos = input_pos) + return self.model(idx, encoder_input={"images": imgs, "aspect_ratio": aspect_ratio}, input_pos = input_pos) + + def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None): + if self.config.model_type == ModelType.TextOnly: + self.text_transformer.setup_caches(max_batch_size, max_seq_length) + else: + assert self.config.model_type == 
ModelType.Flamingo + if max_seq_length is not None: + warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.") + self.model.setup_caches(max_batch_size, dtype=dtype) + + def reset_caches(self): + assert self.config.model_type == ModelType.Flamingo + self.model.reset_caches() @classmethod def from_name(cls, name: str): From e7fa7b45d848c6f9d92b61621673c446e463cab3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 3 Sep 2024 09:37:38 -0700 Subject: [PATCH 12/43] structual model builder --- torchchat/model.py | 81 +++++++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index a056f74ad..a4c32ab2c 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -25,15 +25,43 @@ config_path = Path(f"{str(Path(__file__).parent)}/model_params") - class ModelType(Enum): TextOnly = "text_only" Flamingo = "flamingo" -class ModelSource(Enum): - Native = "native" - Torchtune = "torchtune" +@dataclass +class ModelRecipe: + model_type: ModelType + modules: dict + fusion_class: torch.nn.Module + + @classmethod + def text_only(cls): + return cls( + model_type=ModelType.TextOnly, + modules={'text_transformer': Transformer}, + fusion_class=nn.Identity, + ) + @classmethod + def flamingo(cls): + return cls( + model_type=ModelType.Flamingo, + modules={ + 'encoder': flamingo_vision_encoder, + 'decoder': flamingo_decoder + }, + fusion_class=DeepFusionModel, + ) + + @classmethod + def get_recipe(cls, model_type): + if model_type == ModelType.TextOnly: + return cls.text_only() + elif model_type == ModelType.Flamingo: + return cls.flamingo() + else: + raise ValueError(f"Can not find the model recipe for {model_type}") @dataclass class TransformerArgs: @@ -83,19 +111,16 @@ def from_params(cls, params): @dataclass class ModelArgs: - model_source: ModelSource model_type: ModelType transformer_args: Dict[str, Union[Dict, TransformerArgs]] def __init__( self, transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], - model_source: ModelSource = ModelSource.Native, model_type: ModelType = ModelType.TextOnly, ) -> None: - self._sanity_check(transformer_args, model_source, model_type) + self._sanity_check(transformer_args, model_type) - self.model_source = model_source self.model_type = model_type if isinstance(transformer_args, TransformerArgs): self.transformer_args = {"text": transformer_args} @@ -105,11 +130,8 @@ def __init__( def _sanity_check( self, transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], - model_source: ModelSource, model_type: ModelType, ) -> None: - - assert isinstance(model_source, ModelSource) assert isinstance(model_type, ModelType) assert isinstance(transformer_args, (TransformerArgs, dict)) @@ -122,18 +144,16 @@ def from_params(cls, params_path): # try to interpret as a single transformer config transformer_args: Dict[str, TransformerArgs] = {} transformer_args["text"] = TransformerArgs.from_params(loaded_params) - model_source = ModelSource.Native model_type = ModelType.TextOnly except TypeError: # try to interpret as a dict of transformer configs - model_source = ModelSource(loaded_params["model_source"]) model_type = ModelType(loaded_params["model_type"]) # now only support flamingo model - assert model_source == ModelSource.Torchtune and model_type == ModelType.Flamingo - transformer_args = {k: v for k, v in loaded_params.items() if k != "model_source" and k != "model_type"} + assert model_type == ModelType.Flamingo + transformer_args = {k: v for k, v in 
loaded_params.items() if k != "model_type"} - return cls(transformer_args, model_source, model_type) + return cls(transformer_args, model_type) @classmethod def from_table(cls, name: str): @@ -214,26 +234,19 @@ class Model(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - if config.model_source == ModelSource.Native: - assert ( - config.model_type == ModelType.TextOnly - ), "only text-only model is supported natively. For Flamingo, use torchtune" + # TODO: unify the model init logic + if config.model_type == ModelType.TextOnly: self.text_transformer = Transformer(config.transformer_args["text"]) else: - assert config.model_source == ModelSource.Torchtune - assert config.model_type == ModelType.Flamingo, "currently only Flamingo model is supported from torchtune" - - assert "encoder" in config.transformer_args and "decoder" in config.transformer_args, "config is missing essential transformer args for Flamingo model" - encoder = flamingo_vision_encoder( - **config.transformer_args["encoder"] - ) - decoder = flamingo_decoder( - **config.transformer_args["decoder"] - ) - self.model = DeepFusionModel( - encoder=encoder, - decoder=decoder, - ) + self.model = self.build_model() + + def build_model(self): + recipe = ModelRecipe.get_recipe(self.config.model_type) + modules = {} + for name, module_class in recipe.modules.items(): + modules[name] = module_class(**self.config.transformer_args[name]) + + return recipe.fusion_class(**modules) def forward(self, idx: Tensor, imgs: Optional[Tensor] = None, aspect_ratio: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) -> Tensor: if self.config.model_type == ModelType.TextOnly: From c179bcb9b96e7aa00a3f65731da4dec7a4a86428 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 5 Sep 2024 16:17:17 -0700 Subject: [PATCH 13/43] update torchtune address --- install/install_requirements.sh | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index b6830d091..422515dfc 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -67,8 +67,7 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.4.0" - torchvision + torch=="2.5.0.${NIGHTLY_VERSION}" ) # Install the requirements. `--extra-index-url` tells pip to look for package @@ -79,19 +78,18 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune from Philip forked repository due to flamingo components have not been landed yet. -# TODO: Use torchtune official repository instead, when flamingo components have been merged. +# Install torchtune from github to get the latest feature ( set -x - $PIP_EXECUTABLE install git+https://github.com/pbontrager/torchtune.git@flamingo_components + $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git ) -# # For torchao need to install from github since nightly build doesn't have macos build. -# # TODO: Remove this and install nightly build, once it supports macos -# ( -# set -x -# $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@cfabc13e72fd03934e62a2a03903bc1678235bed -# ) +# For torchao need to install from github since nightly build doesn't have macos build. 
+# TODO: Remove this and install nightly build, once it supports macos +( + set -x + $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 +) if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x From 5ead73b60765d6f0bdf682a66884b488d1652398 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 5 Sep 2024 19:01:32 -0700 Subject: [PATCH 14/43] update install requirement --- install/install_requirements.sh | 13 ++++++++----- torchchat/model.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 422515dfc..cba2d439a 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -68,6 +68,8 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( torch=="2.5.0.${NIGHTLY_VERSION}" + torchvision + ) # Install the requirements. `--extra-index-url` tells pip to look for package @@ -78,18 +80,19 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune from github to get the latest feature +# For torchao need to install from github since nightly build doesn't have macos build. +# TODO: Remove this and install nightly build, once it supports macos ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git + $PIP_EXECUTABLE install --pre torchao --index-url "${TORCH_NIGHTLY_URL}" ) -# For torchao need to install from github since nightly build doesn't have macos build. -# TODO: Remove this and install nightly build, once it supports macos +# Install torchtune from github to get the latest feature ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 + $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git ) + if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x diff --git a/torchchat/model.py b/torchchat/model.py index a4c32ab2c..2cff9c470 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -248,7 +248,7 @@ def build_model(self): return recipe.fusion_class(**modules) - def forward(self, idx: Tensor, imgs: Optional[Tensor] = None, aspect_ratio: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) -> Tensor: + def forward(self, idx: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, batch: Optional[dict] = None) -> Tensor: if self.config.model_type == ModelType.TextOnly: return self.text_transformer(idx, input_pos) else: From 882c336598baf3f3ed65b638b09307c49b0732c9 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 6 Sep 2024 01:55:27 -0700 Subject: [PATCH 15/43] support new torchtune flamingo component --- install/install_requirements.sh | 9 +-------- torchchat/model.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index cba2d439a..6beee14dd 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -69,7 +69,7 @@ fi REQUIREMENTS_TO_INSTALL=( torch=="2.5.0.${NIGHTLY_VERSION}" torchvision - + torchao ) # Install the requirements. `--extra-index-url` tells pip to look for package @@ -80,13 +80,6 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# For torchao need to install from github since nightly build doesn't have macos build. 
-# TODO: Remove this and install nightly build, once it supports macos -( - set -x - $PIP_EXECUTABLE install --pre torchao --index-url "${TORCH_NIGHTLY_URL}" -) - # Install torchtune from github to get the latest feature ( set -x diff --git a/torchchat/model.py b/torchchat/model.py index 2cff9c470..5bfb6999c 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -248,14 +248,21 @@ def build_model(self): return recipe.fusion_class(**modules) - def forward(self, idx: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, batch: Optional[dict] = None) -> Tensor: + def forward(self, + tokens: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + encoder_input: Optional[Dict[str, Tensor]] = None, + encoder_mask: Optional[Tensor] = None) -> Tensor: + if self.config.model_type == ModelType.TextOnly: - return self.text_transformer(idx, input_pos) + return self.text_transformer(tokens, input_pos) else: assert self.config.model_type == ModelType.Flamingo - if imgs is None: - return self.model(idx, input_pos = input_pos) - return self.model(idx, encoder_input={"images": imgs, "aspect_ratio": aspect_ratio}, input_pos = input_pos) + if input_pos: + warnings.warn("input_pos is not used for Flamingo model. Ignoring it.") + if encoder_input is None: + return self.model(tokens, encoder_mask = encoder_mask) + return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask) def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None): if self.config.model_type == ModelType.TextOnly: From 952b8bd4759bb820237fd0b245133435adeeaacf Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 6 Sep 2024 14:37:55 -0700 Subject: [PATCH 16/43] specific version for vision and ao --- install/install_requirements.sh | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 6beee14dd..8f557bfc7 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -41,13 +41,19 @@ fi ) # Since torchchat often uses main-branch features of pytorch, only the nightly -# pip versions will have the required features. The NIGHTLY_VERSION value should +# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should # agree with the third-party/pytorch pinned submodule commit. # # NOTE: If a newly-fetched version of the executorch repo changes the value of -# NIGHTLY_VERSION, you should re-run this script to install the necessary +# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION=dev20240814 +PYTORCH_NIGHTLY_VERSION=dev20240814 + +# Nightly version for torchvision +VISION_NIGHTLY_VERSION=dev20240814 + +# Nightly version for torchao +AO_NIGHTLY_VERSION=dev20240905 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( @@ -67,9 +73,9 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${NIGHTLY_VERSION}" - torchvision - torchao + torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" + torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" + torchao=="0.5.0.${AO_NIGHTLY_VERSION}" ) # Install the requirements. 
`--extra-index-url` tells pip to look for package From e764111027674eeb88f51eeddb16ff15f7c0265a Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 9 Sep 2024 13:30:44 -0700 Subject: [PATCH 17/43] unify text-only model generation pipeline --- torchchat/model.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 5bfb6999c..aed173ca9 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -25,6 +25,11 @@ config_path = Path(f"{str(Path(__file__).parent)}/model_params") +def identity(**kwargs): + if len(kwargs) != 1: + raise ValueError("Only one argument is expected") + return list(kwargs.values())[0] + class ModelType(Enum): TextOnly = "text_only" Flamingo = "flamingo" @@ -40,8 +45,8 @@ class ModelRecipe: def text_only(cls): return cls( model_type=ModelType.TextOnly, - modules={'text_transformer': Transformer}, - fusion_class=nn.Identity, + modules={'text': Transformer}, + fusion_class=identity, ) @classmethod def flamingo(cls): @@ -152,7 +157,6 @@ def from_params(cls, params_path): # now only support flamingo model assert model_type == ModelType.Flamingo transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} - return cls(transformer_args, model_type) @classmethod @@ -236,7 +240,7 @@ def __init__(self, config: ModelArgs) -> None: self.config = config # TODO: unify the model init logic if config.model_type == ModelType.TextOnly: - self.text_transformer = Transformer(config.transformer_args["text"]) + self.text_transformer = self.build_model() else: self.model = self.build_model() @@ -244,8 +248,11 @@ def build_model(self): recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - modules[name] = module_class(**self.config.transformer_args[name]) - + if isinstance(self.config.transformer_args[name], dict): + modules[name] = module_class(**self.config.transformer_args[name]) + else: + modules[name] = module_class(self.config.transformer_args[name]) + return recipe.fusion_class(**modules) def forward(self, From 9679a5b3df95f8ece21b26b050b48e7042ded056 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 9 Sep 2024 13:57:02 -0700 Subject: [PATCH 18/43] convert installation back and bypass torchtune --- install/install_requirements.sh | 22 +++++++--------------- torchchat/model.py | 9 +++++++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 8f557bfc7..7174fffa4 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -41,19 +41,13 @@ fi ) # Since torchchat often uses main-branch features of pytorch, only the nightly -# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should +# pip versions will have the required features. The NIGHTLY_VERSION value should # agree with the third-party/pytorch pinned submodule commit. # # NOTE: If a newly-fetched version of the executorch repo changes the value of -# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary +# NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. 
-PYTORCH_NIGHTLY_VERSION=dev20240814 - -# Nightly version for torchvision -VISION_NIGHTLY_VERSION=dev20240814 - -# Nightly version for torchao -AO_NIGHTLY_VERSION=dev20240905 +NIGHTLY_VERSION=dev20240814 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( @@ -73,9 +67,7 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" - torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" - torchao=="0.5.0.${AO_NIGHTLY_VERSION}" + torch=="2.5.0.${NIGHTLY_VERSION}" ) # Install the requirements. `--extra-index-url` tells pip to look for package @@ -86,12 +78,12 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune from github to get the latest feature +# For torchao need to install from github since nightly build doesn't have macos build. +# TODO: Remove this and install nightly build, once it supports macos ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git + $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 ) - if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x diff --git a/torchchat/model.py b/torchchat/model.py index 5bfb6999c..26c6868cd 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -20,8 +20,13 @@ from torchchat.utils.build_utils import find_multiple, get_precision -from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder -from torchtune.modules.model_fusion import DeepFusionModel +# bypass the import issue, if any +# TODO: remove this once the torchao is ready on macos +try: + from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder + from torchtune.modules.model_fusion import DeepFusionModel +except: + pass config_path = Path(f"{str(Path(__file__).parent)}/model_params") From a3f08eadeb9ee00e12b7e0e81ad5a550b8592d9a Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 9 Sep 2024 16:19:33 -0700 Subject: [PATCH 19/43] restructual model definition --- torchchat/cli/builder.py | 2 +- torchchat/model.py | 101 +++++++++++++++++++-------------- torchchat/utils/gguf_loader.py | 2 +- 3 files changed, 61 insertions(+), 44 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 712b91d93..da2c66b14 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -364,7 +364,7 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] - checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()} + checkpoint = {"model." 
+ k: v for k, v in checkpoint.items()} model.load_state_dict(checkpoint, assign=True, strict=True) return model diff --git a/torchchat/model.py b/torchchat/model.py index 32b7e7f13..e9701d9bf 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -11,6 +11,7 @@ from enum import Enum from pathlib import Path from typing import Dict, Optional, Union +from abc import ABC, abstractmethod import torch import torch.nn as nn @@ -250,16 +251,12 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -class Model(nn.Module): +class Model(ABC, nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - # TODO: unify the model init logic - if config.model_type == ModelType.TextOnly: - self.text_transformer = self.build_model() - else: - self.model = self.build_model() - + self.model = self.build_model() + def build_model(self): recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} @@ -270,57 +267,77 @@ def build_model(self): modules[name] = module_class(self.config.transformer_args[name]) return recipe.fusion_class(**modules) + + @abstractmethod + def forward(self, *args, **kwargs): + pass - def forward(self, - tokens: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - encoder_input: Optional[Dict[str, Tensor]] = None, - encoder_mask: Optional[Tensor] = None) -> Tensor: - - if self.config.model_type == ModelType.TextOnly: - return self.text_transformer(tokens, input_pos) - else: - assert self.config.model_type == ModelType.Flamingo - if input_pos: - warnings.warn("input_pos is not used for Flamingo model. Ignoring it.") - if encoder_input is None: - return self.model(tokens, encoder_mask = encoder_mask) - return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask) - - def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None): - if self.config.model_type == ModelType.TextOnly: - self.text_transformer.setup_caches(max_batch_size, max_seq_length) - else: - assert self.config.model_type == ModelType.Flamingo - if max_seq_length is not None: - warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.") - self.model.setup_caches(max_batch_size, dtype=dtype) + @abstractmethod + def setup_caches(self, *args, **kwargs): + pass - def reset_caches(self): - assert self.config.model_type == ModelType.Flamingo - self.model.reset_caches() + def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): + """Apply extra prefix to the state_dict keys. + + Args: + state_dict (dict): The state dictionary where the model parameters are stored. + prefix (str): The prefix to add to each key in the state_dict. 
+ """ + new_state_dict = {} + for key in state_dict.keys(): + new_key = f"{prefix}{key}" + new_state_dict[new_key] = state_dict[key] + return new_state_dict + + @classmethod + def _get_model_instance(cls, config: ModelArgs): + model_class = MODEL_TYPE_TO_CLASS.get(config.model_type) + if model_class is None: + raise ValueError("Unsupported model type") + return model_class(config) @classmethod def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) + return cls._get_model_instance(ModelArgs.from_name(name)) @classmethod def from_table(cls, name: str): - return cls(ModelArgs.from_table(name)) + return cls._get_model_instance(ModelArgs.from_table(name)) @classmethod def from_params(cls, params_path: str): - return cls(ModelArgs.from_params(params_path)) + return cls._get_model_instance(ModelArgs.from_params(name)) @classmethod def from_gguf(cls, gguf_path: str, **kwargs): - from torchchat.utils.gguf_loader import load_model_and_state_dict + return cls._get_model_instance(ModelArgs.from_gguf(name)) + + +class TextOnlyModel(Model): + def forward(self, tokens: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) -> Tensor: + return self.model(tokens, input_pos) + + def setup_caches(self, max_batch_size, max_seq_length): + self.model.setup_caches(max_batch_size, max_seq_length) + + +class FlamingoModel(Model): + def forward(self, tokens: Optional[Tensor] = None, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor: + if encoder_input is None: + return self.model(tokens, encoder_mask=encoder_mask) + return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask) + + def setup_caches(self, max_batch_size, dtype=None): + self.model.setup_caches(max_batch_size, dtype=dtype) + + def reset_caches(self): + self.model.reset_caches() - model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) - if state_dict != {}: - model.load_state_dict(state_dict, assign=True) - return model +MODEL_TYPE_TO_CLASS = { + ModelType.TextOnly: TextOnlyModel, + ModelType.Flamingo: FlamingoModel, +} class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py index 6e87a8a9f..923622edb 100644 --- a/torchchat/utils/gguf_loader.py +++ b/torchchat/utils/gguf_loader.py @@ -47,7 +47,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: result = copy.deepcopy(gguf_name) for gguf_string, replacement in _name_replacements: result = result.replace(gguf_string, replacement) - result = "text_transformer." + result + result = "model." 
+ result return result From 59337a66f4dd9650dcc365698a6bbcb9382201b6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 9 Sep 2024 16:36:28 -0700 Subject: [PATCH 20/43] update exportation variable name --- torchchat/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/export.py b/torchchat/export.py index d6c9c39c8..db3507ed7 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -58,7 +58,7 @@ def export_for_server( seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length) # Specify that the first dimension of each input is that batch size - dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} + dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: input = ( torch.tensor([[1]], dtype=torch.int, device=device), From 68e29bb3d9a22ab2e65134963c50d0dfb479fb56 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 9 Sep 2024 18:26:54 -0700 Subject: [PATCH 21/43] remove redunctant function --- torchchat/model.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index e9701d9bf..8bebf8e6e 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -275,19 +275,6 @@ def forward(self, *args, **kwargs): @abstractmethod def setup_caches(self, *args, **kwargs): pass - - def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): - """Apply extra prefix to the state_dict keys. - - Args: - state_dict (dict): The state dictionary where the model parameters are stored. - prefix (str): The prefix to add to each key in the state_dict. - """ - new_state_dict = {} - for key in state_dict.keys(): - new_key = f"{prefix}{key}" - new_state_dict[new_key] = state_dict[key] - return new_state_dict @classmethod def _get_model_instance(cls, config: ModelArgs): From 8ea29e7745a10aef8a017927cd80bcca68cfdff3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 00:16:20 -0700 Subject: [PATCH 22/43] 1/n torchtune 3.1 8b --- torchchat/cli/builder.py | 89 ++++++++++--------- torchchat/generate.py | 17 ++-- torchchat/model.py | 36 ++++++-- torchchat/model_params/Meta-Llama-3.1-8B.json | 16 +++- 4 files changed, 103 insertions(+), 55 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index da2c66b14..8a386a129 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -139,7 +139,6 @@ def from_args(cls, args): # -> BuilderArgs: if "chat" in path_basename or "instruct" in path_basename: is_chat_model = True - output_pte_path = getattr(args, "output_pte_path", None) output_dso_path = getattr(args, "output_dso_path", None) if output_pte_path and args.dtype.startswith("fast"): @@ -328,45 +327,51 @@ def _load_model_default(builder_args, only_config=False): assert not builder_args.gguf_path model = _init_model_on_meta_device(builder_args) - # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) - cps = [] - if builder_args.checkpoint_dir is not None: - # Load multiple checkpoint; ignore the single path. 
- builder_args.checkpoint_path = None - for i in range(4): - cp_name = f"consolidated.{i}.pth" - print(f"Loading {cp_name}") - cps.append( - torch.load( - os.path.join(builder_args.checkpoint_dir, cp_name), - map_location=builder_args.device, - mmap=True, - ) - ) - checkpoint = {} - for key in cps[0].keys(): - if not torch.allclose(cps[0][key], cps[1][key]): - values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) - if key.endswith("wo.weight") or key.endswith("w2.weight"): - checkpoint[key] = torch.cat(values, dim=1) - else: - checkpoint[key] = torch.cat(values, dim=0) - else: - checkpoint[key] = cps[0][key] - else: - checkpoint = torch.load( - str(builder_args.checkpoint_path), - map_location=builder_args.device, - mmap=True, - weights_only=True, - ) - - if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): - checkpoint = checkpoint["model"] - - checkpoint = {"model." + k: v for k, v in checkpoint.items()} - - model.load_state_dict(checkpoint, assign=True, strict=True) + hf_checkpoint = torch.load( + str(builder_args.checkpoint_path), mmap=True, weights_only=True + ) + from torchtune.models.convert_weights import meta_to_tune + + tune_checkpoint = meta_to_tune(hf_checkpoint) + + # cps = [] + # if builder_args.checkpoint_dir is not None: + # # Load multiple checkpoint; ignore the single path. + # builder_args.checkpoint_path = None + # for i in range(4): + # cp_name = f"consolidated.{i}.pth" + # print(f"Loading {cp_name}") + # cps.append( + # torch.load( + # os.path.join(builder_args.checkpoint_dir, cp_name), + # map_location=builder_args.device, + # mmap=True, + # ) + # ) + # checkpoint = {} + # for key in cps[0].keys(): + # if not torch.allclose(cps[0][key], cps[1][key]): + # values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) + # if key.endswith("wo.weight") or key.endswith("w2.weight"): + # checkpoint[key] = torch.cat(values, dim=1) + # else: + # checkpoint[key] = torch.cat(values, dim=0) + # else: + # checkpoint[key] = cps[0][key] + # else: + # checkpoint = torch.load( + # str(builder_args.checkpoint_path), + # map_location=builder_args.device, + # mmap=True, + # weights_only=True, + # ) + + # if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): + # checkpoint = checkpoint["model"] + + tune_checkpoint = {"model." 
+ k: v for k, v in tune_checkpoint.items()} + + model.load_state_dict(tune_checkpoint, assign=True, strict=True) return model @@ -534,7 +539,9 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.setup_caches( - max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length + max_batch_size=1, + max_seq_length=max_seq_length + or model.config.transformer_args["text"].max_seq_length, ) model.to(dtype=builder_args.precision) diff --git a/torchchat/generate.py b/torchchat/generate.py index 1d9114d67..e2742720c 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -263,7 +263,7 @@ def __init__( else: self.draft_model = None - self.tokenizer_args.validate_model(self.model) + # self.tokenizer_args.validate_model(self.model) self.tokenizer_args.validate_model(self.draft_model, "draft model") generator_args.validate_build(self.builder_args) generator_args.validate_build(self.speculative_builder_args, "draft model") @@ -508,6 +508,9 @@ def generate( is_speculative = draft_model is not None device, dtype = prompt.device, prompt.dtype + print(f"Generating {max_new_tokens} tokens on device {device} with dtype {dtype}") + + # create an empty tensor of the expected final shape and # fill in the current tokens T = prompt.size(0) @@ -517,7 +520,8 @@ def generate( if start_pos == 0: model = model.to(device=device) with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.setup_caches(max_batch_size=1, dtype=torch.bfloat16) if is_speculative and draft_model is not model: draft_model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length @@ -698,10 +702,11 @@ def chat( self.system_prompt = input("What is your system prompt? 
\n") else: - max_seq_length = min( - encoded.size(0) + generator_args.max_new_tokens, - self.model.config.transformer_args["text"].block_size, - ) + # max_seq_length = min( + # encoded.size(0) + generator_args.max_new_tokens, + # ) + max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"] + # max_seq_length = 4096 max_seq_length = ( max_seq_length + self.speculative_builder_args.speculate_k + 1 diff --git a/torchchat/model.py b/torchchat/model.py index 8bebf8e6e..4ef4b67f1 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -34,6 +34,7 @@ try: from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder from torchtune.modules.model_fusion import DeepFusionModel + from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder except: pass @@ -46,6 +47,7 @@ def identity(**kwargs): class ModelType(Enum): TextOnly = "text_only" + Llama3_1 = "llama3_1" Flamingo = "flamingo" @@ -62,6 +64,15 @@ def text_only(cls): modules={'text': Transformer}, fusion_class=identity, ) + + @classmethod + def llama3_1(cls): + return cls( + model_type=ModelType.Llama3_1, + modules={'text': llama3_1_builder}, + fusion_class=identity, + ) + @classmethod def flamingo(cls): return cls( @@ -79,6 +90,8 @@ def get_recipe(cls, model_type): return cls.text_only() elif model_type == ModelType.Flamingo: return cls.flamingo() + elif model_type == ModelType.Llama3_1: + return cls.llama3_1() else: raise ValueError(f"Can not find the model recipe for {model_type}") @@ -170,9 +183,6 @@ def from_params(cls, params_path): except TypeError: # try to interpret as a dict of transformer configs model_type = ModelType(loaded_params["model_type"]) - - # now only support flamingo model - assert model_type == ModelType.Flamingo transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} return cls(transformer_args, model_type) @@ -280,7 +290,7 @@ def setup_caches(self, *args, **kwargs): def _get_model_instance(cls, config: ModelArgs): model_class = MODEL_TYPE_TO_CLASS.get(config.model_type) if model_class is None: - raise ValueError("Unsupported model type") + raise ValueError("Unsupported model type:", str(config.model_type)) return model_class(config) @classmethod @@ -301,20 +311,31 @@ def from_gguf(cls, gguf_path: str, **kwargs): class TextOnlyModel(Model): - def forward(self, tokens: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) -> Tensor: + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.model(tokens, input_pos) def setup_caches(self, max_batch_size, max_seq_length): self.model.setup_caches(max_batch_size, max_seq_length) +class Llama31Model(Model): + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + return self.model(tokens=tokens, input_pos=input_pos) + + def setup_caches(self, max_batch_size, dtype): + self.model.setup_caches(max_batch_size, dtype=dtype) + + def reset_caches(self): + self.model.reset_caches() + + class FlamingoModel(Model): - def forward(self, tokens: Optional[Tensor] = None, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor: if encoder_input is None: return self.model(tokens, encoder_mask=encoder_mask) return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask) - def setup_caches(self, max_batch_size, dtype=None): + def 
setup_caches(self, max_batch_size, dtype): self.model.setup_caches(max_batch_size, dtype=dtype) def reset_caches(self): @@ -324,6 +345,7 @@ def reset_caches(self): MODEL_TYPE_TO_CLASS = { ModelType.TextOnly: TextOnlyModel, ModelType.Flamingo: FlamingoModel, + ModelType.Llama3_1: Llama31Model, } class Transformer(nn.Module): diff --git a/torchchat/model_params/Meta-Llama-3.1-8B.json b/torchchat/model_params/Meta-Llama-3.1-8B.json index 0d3808205..893a80a04 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B.json @@ -1 +1,15 @@ -{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true} +{ + "model_type": "llama3_1", + "text": { + "vocab_size": 128256, + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "embed_dim": 4096, + "max_seq_len": 131072, + "intermediate_dim": 14336, + "attn_dropout": 0.0, + "norm_eps": 1e-5, + "rope_base": 500000.0 + } +} From 4a6f70398a154a557dd198c33a1c9adef128f0a8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 00:17:29 -0700 Subject: [PATCH 23/43] installation update --- install/install_requirements.sh | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 7174fffa4..06c77549e 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -41,13 +41,19 @@ fi ) # Since torchchat often uses main-branch features of pytorch, only the nightly -# pip versions will have the required features. The NIGHTLY_VERSION value should +# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should # agree with the third-party/pytorch pinned submodule commit. # # NOTE: If a newly-fetched version of the executorch repo changes the value of -# NIGHTLY_VERSION, you should re-run this script to install the necessary +# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION=dev20240814 +PYTORCH_NIGHTLY_VERSION=dev20240814 + +# Nightly version for torchvision +VISION_NIGHTLY_VERSION=dev20240814 + +# Nightly version for torchao +AO_NIGHTLY_VERSION=dev20240905 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( @@ -67,10 +73,12 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${NIGHTLY_VERSION}" + torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" + torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" + torchao=="0.5.0.${AO_NIGHTLY_VERSION}" ) -# Install the requirements. `--extra-index-url` tells pip to look for package +# Install the requirements. --extra-index-url tells pip to look for package # versions on the provided URL if they aren't available on the default URL. ( set -x @@ -78,12 +86,12 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# For torchao need to install from github since nightly build doesn't have macos build. 
-# TODO: Remove this and install nightly build, once it supports macos +# Install torchtune from github to get the latest feature ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 + $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git ) + if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x From 5ec0811cd0ecfc4f60ee62f58cb40c18a97ab439 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 13:08:56 -0700 Subject: [PATCH 24/43] torchtune 3.1 8b / 30b --- torchchat/cli/builder.py | 4 ++++ torchchat/generate.py | 5 +---- torchchat/model.py | 2 +- torchchat/model_params/Meta-Llama-3.1-70B.json | 16 +++++++++++++++- torchchat/model_params/Meta-Llama-3.1-8B.json | 2 +- 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 8a386a129..0684a100a 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -35,6 +35,9 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model +from torchtune.training import set_default_dtype + + @dataclass class BuilderArgs: @@ -372,6 +375,7 @@ def _load_model_default(builder_args, only_config=False): tune_checkpoint = {"model." + k: v for k, v in tune_checkpoint.items()} model.load_state_dict(tune_checkpoint, assign=True, strict=True) + return model diff --git a/torchchat/generate.py b/torchchat/generate.py index e2742720c..578177f21 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -506,10 +506,7 @@ def generate( torch.manual_seed(seed) is_speculative = draft_model is not None - device, dtype = prompt.device, prompt.dtype - - print(f"Generating {max_new_tokens} tokens on device {device} with dtype {dtype}") - + device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and # fill in the current tokens diff --git a/torchchat/model.py b/torchchat/model.py index 4ef4b67f1..389bba1c1 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -303,7 +303,7 @@ def from_table(cls, name: str): @classmethod def from_params(cls, params_path: str): - return cls._get_model_instance(ModelArgs.from_params(name)) + return cls._get_model_instance(ModelArgs.from_params(params_path)) @classmethod def from_gguf(cls, gguf_path: str, **kwargs): diff --git a/torchchat/model_params/Meta-Llama-3.1-70B.json b/torchchat/model_params/Meta-Llama-3.1-70B.json index d3e9a73fa..c59961c63 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B.json @@ -1 +1,15 @@ -{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true} +{ + "model_type": "llama3_1", + "text": { + "vocab_size": 128256, + "num_layers": 80, + "num_heads": 64, + "num_kv_heads": 8, + "embed_dim": 8192, + "max_seq_len": 8192, + "intermediate_dim": 28672, + "attn_dropout": 0.0, + "norm_eps": 1e-5, + "rope_base": 500000.0 + } +} diff --git a/torchchat/model_params/Meta-Llama-3.1-8B.json b/torchchat/model_params/Meta-Llama-3.1-8B.json index 893a80a04..e9ded77bd 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B.json @@ -6,7 +6,7 @@ "num_heads": 32, "num_kv_heads": 8, "embed_dim": 4096, - "max_seq_len": 131072, + "max_seq_len": 8192, "intermediate_dim": 14336, "attn_dropout": 0.0, "norm_eps": 1e-5, 
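Taken together, the torchtune path added in the last three patches reduces weight loading to four steps: read the meta-format checkpoint, rename its parameters with torchtune's meta_to_tune, prefix every key with "model." so it targets the fused module held by the wrapper, and load strictly with assign=True. A minimal sketch of that flow, assuming a standalone helper (load_tune_checkpoint and its arguments are illustrative, not part of the patches):

import torch
from torchtune.models.convert_weights import meta_to_tune

def load_tune_checkpoint(model, checkpoint_path, device="cpu"):
    # Meta-format names such as "layers.0.attention.wq.weight" are renamed to
    # the torchtune-style "layers.0.attn.q_proj.weight" that modules built
    # from torchtune recipes expect.
    meta_state = torch.load(
        str(checkpoint_path), map_location=device, mmap=True, weights_only=True
    )
    tune_state = meta_to_tune(meta_state)
    # The Model wrapper keeps the fused module under `self.model`, so every
    # key needs the extra "model." prefix before a strict load.
    tune_state = {"model." + k: v for k, v in tune_state.items()}
    model.load_state_dict(tune_state, assign=True, strict=True)
    return model

The following patch gates this conversion behind params tables whose names end in "Tune", so checkpoints for the existing torchchat-native Transformer keep their original loading path.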
From f83154ac3cfb2b438b1eda53454785f386830101 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 14:26:18 -0700 Subject: [PATCH 25/43] bring torchchat llama3.1 back --- torchchat/cli/builder.py | 85 ++++++++++--------- torchchat/generate.py | 26 ++++-- torchchat/model_config/models.json | 12 +++ .../model_params/Meta-Llama-3.1-70B-Tune.json | 15 ++++ .../model_params/Meta-Llama-3.1-70B.json | 16 +--- .../model_params/Meta-Llama-3.1-8B-Tune.json | 15 ++++ torchchat/model_params/Meta-Llama-3.1-8B.json | 16 +--- 7 files changed, 105 insertions(+), 80 deletions(-) create mode 100644 torchchat/model_params/Meta-Llama-3.1-70B-Tune.json create mode 100644 torchchat/model_params/Meta-Llama-3.1-8B-Tune.json diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 0684a100a..04ce4fb2d 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -36,6 +36,7 @@ from torchchat.utils.quantize import quantize_model from torchtune.training import set_default_dtype +from torchtune.models.convert_weights import meta_to_tune @@ -333,48 +334,48 @@ def _load_model_default(builder_args, only_config=False): hf_checkpoint = torch.load( str(builder_args.checkpoint_path), mmap=True, weights_only=True ) - from torchtune.models.convert_weights import meta_to_tune - - tune_checkpoint = meta_to_tune(hf_checkpoint) - - # cps = [] - # if builder_args.checkpoint_dir is not None: - # # Load multiple checkpoint; ignore the single path. - # builder_args.checkpoint_path = None - # for i in range(4): - # cp_name = f"consolidated.{i}.pth" - # print(f"Loading {cp_name}") - # cps.append( - # torch.load( - # os.path.join(builder_args.checkpoint_dir, cp_name), - # map_location=builder_args.device, - # mmap=True, - # ) - # ) - # checkpoint = {} - # for key in cps[0].keys(): - # if not torch.allclose(cps[0][key], cps[1][key]): - # values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) - # if key.endswith("wo.weight") or key.endswith("w2.weight"): - # checkpoint[key] = torch.cat(values, dim=1) - # else: - # checkpoint[key] = torch.cat(values, dim=0) - # else: - # checkpoint[key] = cps[0][key] - # else: - # checkpoint = torch.load( - # str(builder_args.checkpoint_path), - # map_location=builder_args.device, - # mmap=True, - # weights_only=True, - # ) - - # if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): - # checkpoint = checkpoint["model"] - - tune_checkpoint = {"model." + k: v for k, v in tune_checkpoint.items()} - - model.load_state_dict(tune_checkpoint, assign=True, strict=True) + + cps = [] + if builder_args.params_table.endswith("Tune"): + print("Loading Tune checkpoint") + checkpoint = meta_to_tune(hf_checkpoint) + elif builder_args.checkpoint_dir is not None: + # Load multiple checkpoint; ignore the single path. 
+ builder_args.checkpoint_path = None + for i in range(4): + cp_name = f"consolidated.{i}.pth" + print(f"Loading {cp_name}") + cps.append( + torch.load( + os.path.join(builder_args.checkpoint_dir, cp_name), + map_location=builder_args.device, + mmap=True, + ) + ) + checkpoint = {} + for key in cps[0].keys(): + if not torch.allclose(cps[0][key], cps[1][key]): + values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key]) + if key.endswith("wo.weight") or key.endswith("w2.weight"): + checkpoint[key] = torch.cat(values, dim=1) + else: + checkpoint[key] = torch.cat(values, dim=0) + else: + checkpoint[key] = cps[0][key] + else: + checkpoint = torch.load( + str(builder_args.checkpoint_path), + map_location=builder_args.device, + mmap=True, + weights_only=True, + ) + + if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): + checkpoint = checkpoint["model"] + + + checkpoint = {"model." + k: v for k, v in checkpoint.items()} + model.load_state_dict(checkpoint, assign=True, strict=True) return model diff --git a/torchchat/generate.py b/torchchat/generate.py index 578177f21..f194f4688 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -116,6 +116,7 @@ class GeneratorArgs: speculate_k: int = 5 sequential_prefill: bool = False max_autotune: bool = False + is_torchtune_model: bool = False def __post_init__(self): if self.compile_prefill and self.sequential_prefill: @@ -161,6 +162,7 @@ def from_args(cls, args): speculate_k=args.speculate_k, sequential_prefill=sequential_prefill, max_autotune=args.max_autotune, + is_torchtune_model=args.model.endswith("tune"), ) @@ -197,6 +199,7 @@ def __init__( self.profile = profile self.quantize = quantize self.draft_quantize = draft_quantize + self.is_torchtune_model = generator_args.is_torchtune_model # global print # from tp import maybe_init_dist @@ -517,8 +520,10 @@ def generate( if start_pos == 0: model = model.to(device=device) with torch.device(device): - # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - model.setup_caches(max_batch_size=1, dtype=torch.bfloat16) + if self.is_torchtune_model: + model.setup_caches(max_batch_size=1, dtype=model.dtype) + else: + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: draft_model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length @@ -687,7 +692,12 @@ def chat( self.system_prompt = None # Set up our max_seq_length - if generator_args.chat_mode: + + # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong + # TODO: unify the max_seq_length config representation. + if generator_args.is_torchtune_model: + max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"] + elif generator_args.chat_mode: max_seq_length = self.model.config.transformer_args["text"].max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" @@ -699,11 +709,11 @@ def chat( self.system_prompt = input("What is your system prompt? 
\n") else: - # max_seq_length = min( - # encoded.size(0) + generator_args.max_new_tokens, - # ) - max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"] - # max_seq_length = 4096 + max_seq_length = min( + encoded.size(0) + generator_args.max_new_tokens, + self.model.config.transformer_args["text"].block_size, + ) + max_seq_length = ( max_seq_length + self.speculative_builder_args.speculate_k + 1 diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json index ab0abb7d6..ca8c5acdf 100644 --- a/torchchat/model_config/models.json +++ b/torchchat/model_config/models.json @@ -57,6 +57,18 @@ "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct", "transformer_params_key": "Meta-Llama-3.1-70B" }, + "meta-llama/Meta-Llama-3.1-8B-Instruct-Tune": { + "aliases": ["llama3.1-tune", "llama3.1-chat-tune", "llama3.1-instruct-tune"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "transformer_params_key": "Meta-Llama-3.1-8B-Tune" + }, + "meta-llama/Meta-Llama-3.1-70B-Instruct-Tune": { + "aliases": ["llama3.1-70b-tune"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct", + "transformer_params_key": "Meta-Llama-3.1-70B-Tune" + }, "meta-llama/CodeLlama-7b-Python-hf": { "aliases": ["codellama", "codellama-7b"], "distribution_channel": "HuggingFaceSnapshot", diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json new file mode 100644 index 000000000..c59961c63 --- /dev/null +++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json @@ -0,0 +1,15 @@ +{ + "model_type": "llama3_1", + "text": { + "vocab_size": 128256, + "num_layers": 80, + "num_heads": 64, + "num_kv_heads": 8, + "embed_dim": 8192, + "max_seq_len": 8192, + "intermediate_dim": 28672, + "attn_dropout": 0.0, + "norm_eps": 1e-5, + "rope_base": 500000.0 + } +} diff --git a/torchchat/model_params/Meta-Llama-3.1-70B.json b/torchchat/model_params/Meta-Llama-3.1-70B.json index c59961c63..d3e9a73fa 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B.json @@ -1,15 +1 @@ -{ - "model_type": "llama3_1", - "text": { - "vocab_size": 128256, - "num_layers": 80, - "num_heads": 64, - "num_kv_heads": 8, - "embed_dim": 8192, - "max_seq_len": 8192, - "intermediate_dim": 28672, - "attn_dropout": 0.0, - "norm_eps": 1e-5, - "rope_base": 500000.0 - } -} +{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true} diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json new file mode 100644 index 000000000..e9ded77bd --- /dev/null +++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json @@ -0,0 +1,15 @@ +{ + "model_type": "llama3_1", + "text": { + "vocab_size": 128256, + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "embed_dim": 4096, + "max_seq_len": 8192, + "intermediate_dim": 14336, + "attn_dropout": 0.0, + "norm_eps": 1e-5, + "rope_base": 500000.0 + } +} diff --git a/torchchat/model_params/Meta-Llama-3.1-8B.json b/torchchat/model_params/Meta-Llama-3.1-8B.json index e9ded77bd..0d3808205 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B.json @@ -1,15 +1 @@ -{ - "model_type": "llama3_1", - 
"text": { - "vocab_size": 128256, - "num_layers": 32, - "num_heads": 32, - "num_kv_heads": 8, - "embed_dim": 4096, - "max_seq_len": 8192, - "intermediate_dim": 14336, - "attn_dropout": 0.0, - "norm_eps": 1e-5, - "rope_base": 500000.0 - } -} +{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true} From f891fb119ce06c78495d500cd391c6c625f4d421 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 14:32:32 -0700 Subject: [PATCH 26/43] bring tok vali back to torchchat model + revert install_requirements.sh --- install/install_requirements.sh | 24 ++++++++---------------- torchchat/generate.py | 5 ++++- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 06c77549e..7174fffa4 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -41,19 +41,13 @@ fi ) # Since torchchat often uses main-branch features of pytorch, only the nightly -# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should +# pip versions will have the required features. The NIGHTLY_VERSION value should # agree with the third-party/pytorch pinned submodule commit. # # NOTE: If a newly-fetched version of the executorch repo changes the value of -# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary +# NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -PYTORCH_NIGHTLY_VERSION=dev20240814 - -# Nightly version for torchvision -VISION_NIGHTLY_VERSION=dev20240814 - -# Nightly version for torchao -AO_NIGHTLY_VERSION=dev20240905 +NIGHTLY_VERSION=dev20240814 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( @@ -73,12 +67,10 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" - torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" - torchao=="0.5.0.${AO_NIGHTLY_VERSION}" + torch=="2.5.0.${NIGHTLY_VERSION}" ) -# Install the requirements. --extra-index-url tells pip to look for package +# Install the requirements. `--extra-index-url` tells pip to look for package # versions on the provided URL if they aren't available on the default URL. ( set -x @@ -86,12 +78,12 @@ REQUIREMENTS_TO_INSTALL=( "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune from github to get the latest feature +# For torchao need to install from github since nightly build doesn't have macos build. 
+# TODO: Remove this and install nightly build, once it supports macos ( set -x - $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git + $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3 ) - if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x diff --git a/torchchat/generate.py b/torchchat/generate.py index f194f4688..f6a0f52b5 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -266,7 +266,10 @@ def __init__( else: self.draft_model = None - # self.tokenizer_args.validate_model(self.model) + # torchtune model does not contain essential info for validation + # TODO: refactor model config to be more generic + if not self.is_torchtune_model: + self.tokenizer_args.validate_model(self.model) self.tokenizer_args.validate_model(self.draft_model, "draft model") generator_args.validate_build(self.builder_args) generator_args.validate_build(self.speculative_builder_args, "draft model") From 6c97eb7600658378ce17301752500951309a2705 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 14:43:45 -0700 Subject: [PATCH 27/43] solve bugs related to tt model support --- torchchat/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index f6a0f52b5..80db44887 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -200,6 +200,7 @@ def __init__( self.quantize = quantize self.draft_quantize = draft_quantize self.is_torchtune_model = generator_args.is_torchtune_model + self.dtype = builder_args.precision # global print # from tp import maybe_init_dist @@ -524,7 +525,7 @@ def generate( model = model.to(device=device) with torch.device(device): if self.is_torchtune_model: - model.setup_caches(max_batch_size=1, dtype=model.dtype) + model.setup_caches(max_batch_size=1, dtype=self.dtype) else: model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: From 11217a4966376dbdc94cdb4136a2b79b9ad27601 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 14:51:15 -0700 Subject: [PATCH 28/43] bypass torchtune import issue --- torchchat/cli/builder.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 04ce4fb2d..43055b1e7 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -35,8 +35,12 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model -from torchtune.training import set_default_dtype -from torchtune.models.convert_weights import meta_to_tune +# bypass the import issue before torchao is ready on macos +try: + from torchtune.training import set_default_dtype + from torchtune.models.convert_weights import meta_to_tune +except: + pass From 758af1085b86f9c0af8f287a33b67a158acd7418 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:18:35 -0700 Subject: [PATCH 29/43] solve Jack's wonderful comments --- torchchat/model.py | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 8ea83ead6..257decaac 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -42,22 +42,41 @@ class ModelType(Enum): TextOnly = "text_only" Flamingo = "flamingo" +# Type for objects that can generate nn.Module instance +ModuleLike = Union[nn.Module, Callable[..., nn.Module]] @dataclass class ModelRecipe: + """ + A class in TorchChat that 
describes and contains all supported model structures in TorchChat. + + ModelRecipe represents a model as a collection of Transformer modules and a fusion module, + providing a standardized and centralized way to define and build models in TorchChat. + Attributes: + model_type (ModelType): + The type of the model. + modules (Dict[str, ModuleLike]): + A dictionary of ModuleLike modules, where each key is the module name and each + value is a ModuleLike object that generates the transformer. + The names of the Transformer modules should match the corresponding names in the + fusion class and the JSON file holding model hyperparameters. + fusion_class (ModuleLike): + A ModuleLike object that generates a fusion module by taking the constructed modules above. + """ + model_type: ModelType - modules: dict - fusion_class: torch.nn.Module + modules: Dict[str, ModuleLike] + fusion_class: ModuleLike @classmethod - def text_only(cls): + def _text_only(cls): return cls( model_type=ModelType.TextOnly, modules={'text_transformer': Transformer}, fusion_class=nn.Identity, ) @classmethod - def flamingo(cls): + def _flamingo(cls): return cls( model_type=ModelType.Flamingo, modules={ @@ -70,9 +89,9 @@ def flamingo(cls): @classmethod def get_recipe(cls, model_type): if model_type == ModelType.TextOnly: - return cls.text_only() + return cls._text_only() elif model_type == ModelType.Flamingo: - return cls.flamingo() + return cls._flamingo() else: raise ValueError(f"Can not find the model recipe for {model_type}") @@ -139,6 +158,7 @@ def __init__( self.model_type = model_type if isinstance(transformer_args, TransformerArgs): + assert model_type == ModelType.TextOnly self.transformer_args = {"text": transformer_args} else: self.transformer_args = transformer_args @@ -165,7 +185,7 @@ def from_params(cls, params_path): # try to interpret as a dict of transformer configs model_type = ModelType(loaded_params["model_type"]) - # now only support flamingo model + # Currently only supporting flamingo model assert model_type == ModelType.Flamingo transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} @@ -247,6 +267,9 @@ def update(self, input_pos, k_val, v_val): class Model(nn.Module): + """ + The entrance for model construction in tochchat. + """ def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config @@ -256,7 +279,14 @@ def __init__(self, config: ModelArgs) -> None: else: self.model = self.build_model() - def build_model(self): + def build_model(self) -> nn.Module: + """ + Builds a model based on the provided configuration. + This method retrieves a ModelRecipe instance corresponding to the specified model type, + constructs the required Transformer modules, and combines them using the fusion class. + Returns: + The constructed model instance. 
+ """ recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): From 80b54816bc3baad2b8f31e248ed32c9cb8bcaa79 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:21:10 -0700 Subject: [PATCH 30/43] remveo extra dot --- dist_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_run.py b/dist_run.py index 2d415eaa2..ee9d159b2 100644 --- a/dist_run.py +++ b/dist_run.py @@ -122,7 +122,7 @@ def main(): gpu_memory_monitor = GPUMemoryMonitor("cuda") logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") - config = ModelArgs.from_name(MODEL_NAME)..transformer_args['text'] + config = ModelArgs.from_name(MODEL_NAME).transformer_args['text'] logger.info(f"Chat Model Config: {config}") tokenizer = _build_chat_tokenizer() From 1cc79091720f2dafa184799589de8819689c7895 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:26:04 -0700 Subject: [PATCH 31/43] add type.Callable --- torchchat/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/model.py b/torchchat/model.py index 257decaac..bd30d2e5f 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.nn as nn From 95684d94acd723c71fabaec5cb10c375dfe2ec14 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:31:30 -0700 Subject: [PATCH 32/43] fix torchchat typos --- torchchat/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index bd30d2e5f..500f2c71c 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -48,10 +48,10 @@ class ModelType(Enum): @dataclass class ModelRecipe: """ - A class in TorchChat that describes and contains all supported model structures in TorchChat. + The class describes and contains all supported model structures in torchchat. ModelRecipe represents a model as a collection of Transformer modules and a fusion module, - providing a standardized and centralized way to define and build models in TorchChat. + providing a standardized and centralized way to define and build models in torchchat. Attributes: model_type (ModelType): The type of the model. @@ -268,7 +268,7 @@ def update(self, input_pos, k_val, v_val): class Model(nn.Module): """ - The entrance for model construction in tochchat. + The entrance for model construction in torchchat. 
""" def __init__(self, config: ModelArgs) -> None: super().__init__() From 6dc2aabf1c94d2b0801131c696f950af9752caf0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:38:13 -0700 Subject: [PATCH 33/43] solve bug when args.model is None --- torchchat/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 80db44887..182488652 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -162,7 +162,7 @@ def from_args(cls, args): speculate_k=args.speculate_k, sequential_prefill=sequential_prefill, max_autotune=args.max_autotune, - is_torchtune_model=args.model.endswith("tune"), + is_torchtune_model=args.model and args.model.endswith("tune"), ) From 08a05b766ccb83594423675d167375e7d4671853 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 18:58:29 -0700 Subject: [PATCH 34/43] support builder_args.params_table is None --- torchchat/cli/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 43055b1e7..a867d1e85 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -340,7 +340,7 @@ def _load_model_default(builder_args, only_config=False): ) cps = [] - if builder_args.params_table.endswith("Tune"): + if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") checkpoint = meta_to_tune(hf_checkpoint) elif builder_args.checkpoint_dir is not None: From 257b1ceb0e0552f6ade3bc7de0d8691f348e1004 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 19:11:23 -0700 Subject: [PATCH 35/43] remove all .DS_Store --- .DS_Store | Bin 6148 -> 0 bytes build/.DS_Store | Bin 6148 -> 0 bytes torchchat/model_params/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 build/.DS_Store delete mode 100644 torchchat/model_params/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 16d875f048b003f89438f610b00ecfc05d69253a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~J&pn~427Thk&tL8DbsL(y+MTF1YBUnPJ=WO#fm;h=h<(?J6KYLuU74ZjM<6h0 J5P^Rs@C1L66M6su diff --git a/build/.DS_Store b/build/.DS_Store deleted file mode 100644 index 658fefb065d303740d88fb70cd071aca7d48bab7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKyH3ME5S)b+k!T_+%KHWWz>2~b@BzF`g@X|hqIJc0@oCIH8q3f`qKRguy}8@l zxzkPI^#ZW%M}G}$0Icbb`0`kSWhr|)s-4^PA6I1WybNdYM!1*Cu!kOD_5P^ENzar6lyN&zWwstWk`q0t??!YMI6 z9UP(sAg&k=<2-r^V)Fp8E1VJ;p;=OiNwsP*Ea{B5%IgZJ#H7Qj`LMd#szb53o#(em zhjob>rGONeD{!96jo1Gh`XBxOoTQZ$kOC*AfUS17yDgtowRQ10ueFVSPxqW}x*O*~ o;SlAR80DA?FUMDrlzGkP-0upf#Go@CbfSI+To;)X_-_S%08E`2H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Tue, 10 Sep 2024 19:13:57 -0700 Subject: [PATCH 36/43] bring gguf back --- torchchat/model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchchat/model.py b/torchchat/model.py index 8c719c713..5af4c21e5 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -337,7 +337,12 @@ def from_params(cls, params_path: str): @classmethod def from_gguf(cls, gguf_path: str, **kwargs): - return cls._get_model_instance(ModelArgs.from_gguf(name)) + from torchchat.utils.gguf_loader import load_model_and_state_dict + + model, state_dict = 
load_model_and_state_dict(gguf_path, **kwargs) + if state_dict != {}: + model.load_state_dict(state_dict, assign=True) + return model class TextOnlyModel(Model): From 192841d7e1ef97a88544476428a9026658ea1b69 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 22:30:56 -0700 Subject: [PATCH 37/43] remove reduntant updates --- torchchat/cli/builder.py | 9 ++++----- torchchat/generate.py | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index a867d1e85..57ed4f96b 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -335,14 +335,14 @@ def _load_model_default(builder_args, only_config=False): assert not builder_args.gguf_path model = _init_model_on_meta_device(builder_args) - hf_checkpoint = torch.load( - str(builder_args.checkpoint_path), mmap=True, weights_only=True - ) cps = [] if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") - checkpoint = meta_to_tune(hf_checkpoint) + meta_checkpoint = torch.load( + str(builder_args.checkpoint_path), mmap=True, weights_only=True + ) + checkpoint = meta_to_tune(meta_checkpoint) elif builder_args.checkpoint_dir is not None: # Load multiple checkpoint; ignore the single path. builder_args.checkpoint_path = None @@ -377,7 +377,6 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] - checkpoint = {"model." + k: v for k, v in checkpoint.items()} model.load_state_dict(checkpoint, assign=True, strict=True) diff --git a/torchchat/generate.py b/torchchat/generate.py index 182488652..722e5156c 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -513,7 +513,7 @@ def generate( torch.manual_seed(seed) is_speculative = draft_model is not None - device, dtype = prompt.device, prompt.dtype + device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and # fill in the current tokens @@ -718,7 +718,6 @@ def chat( self.model.config.transformer_args["text"].block_size, ) - max_seq_length = ( max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None From a5556f460f53cb1f5c63269b03e294939929136e Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 22:35:10 -0700 Subject: [PATCH 38/43] bring checkpoint back --- torchchat/cli/builder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 57ed4f96b..674de9c9f 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -335,17 +335,15 @@ def _load_model_default(builder_args, only_config=False): assert not builder_args.gguf_path model = _init_model_on_meta_device(builder_args) + checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) - cps = [] if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") - meta_checkpoint = torch.load( - str(builder_args.checkpoint_path), mmap=True, weights_only=True - ) - checkpoint = meta_to_tune(meta_checkpoint) + checkpoint = meta_to_tune(checkpoint) elif builder_args.checkpoint_dir is not None: # Load multiple checkpoint; ignore the single path. 
builder_args.checkpoint_path = None + cps = [] for i in range(4): cp_name = f"consolidated.{i}.pth" print(f"Loading {cp_name}") From d395f7f8d6e2d783cd988f0903c8754efdd3bde0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 23:19:27 -0700 Subject: [PATCH 39/43] debug --- torchchat/cli/builder.py | 5 +++-- torchchat/generate.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 674de9c9f..6589d4c72 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -147,6 +147,7 @@ def from_args(cls, args): # -> BuilderArgs: if "chat" in path_basename or "instruct" in path_basename: is_chat_model = True + output_pte_path = getattr(args, "output_pte_path", None) output_dso_path = getattr(args, "output_dso_path", None) if output_pte_path and args.dtype.startswith("fast"): @@ -335,11 +336,11 @@ def _load_model_default(builder_args, only_config=False): assert not builder_args.gguf_path model = _init_model_on_meta_device(builder_args) - checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") - checkpoint = meta_to_tune(checkpoint) + meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) + checkpoint = meta_to_tune(meta_checkpoint) elif builder_args.checkpoint_dir is not None: # Load multiple checkpoint; ignore the single path. builder_args.checkpoint_path = None diff --git a/torchchat/generate.py b/torchchat/generate.py index 722e5156c..13daf145d 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -302,7 +302,8 @@ def sample( need_probs: bool, temperature: float = 1.0, top_k: Optional[int] = None, - ): + ): + print(f"logits {logits}") if temperature == 0 and not need_probs: _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1) return (idx_next, None) @@ -513,7 +514,7 @@ def generate( torch.manual_seed(seed) is_speculative = draft_model is not None - device, dtype = prompt.device, prompt.dtype + device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and # fill in the current tokens From 8130901a5cafe45499d4fa4fceabe291d274dd30 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 23:21:26 -0700 Subject: [PATCH 40/43] debug --- torchchat/generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index 13daf145d..c5781b3ab 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -323,6 +323,7 @@ def prefill( # logging.debug(f"x: {x}, input_pos: {input_pos}") width = x.size(1) assert input_pos.size(0) == width + print("x: ", x, "input_pos: ", input_pos, "width: ", width) if sequential_prefill: for i in range(width): From dc4015261a03f9dad32e6df4a648f3c415b8d00b Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 10 Sep 2024 23:33:50 -0700 Subject: [PATCH 41/43] debug --- torchchat/generate.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index c5781b3ab..256cce54d 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -324,15 +324,20 @@ def prefill( width = x.size(1) assert input_pos.size(0) == width print("x: ", x, "input_pos: ", input_pos, "width: ", width) + print("sequential_prefill: ", sequential_prefill) if sequential_prefill: for i in range(width): + print("i: ", i) x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) + 
print("x_sliced: ", x_sliced, "ip_sliced: ", ip_sliced) # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i]) + print("logits: ", logits) else: # input_pos: [B, S] logits = model(x, input_pos) + print("logits: ", logits) # print(f"logits {logits.shape}") # print(f"x: {x},\n input_pos: {input_pos}\n") From 6cf7db7859f7e6608d093028cfbc129ddb9ada77 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 11 Sep 2024 00:27:29 -0700 Subject: [PATCH 42/43] new factory func to produce Model from modelargs --- torchchat/generate.py | 7 ------- torchchat/model.py | 8 ++++++-- torchchat/utils/gguf_loader.py | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 256cce54d..67e9b9ae8 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -303,7 +303,6 @@ def sample( temperature: float = 1.0, top_k: Optional[int] = None, ): - print(f"logits {logits}") if temperature == 0 and not need_probs: _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1) return (idx_next, None) @@ -323,21 +322,15 @@ def prefill( # logging.debug(f"x: {x}, input_pos: {input_pos}") width = x.size(1) assert input_pos.size(0) == width - print("x: ", x, "input_pos: ", input_pos, "width: ", width) - print("sequential_prefill: ", sequential_prefill) if sequential_prefill: for i in range(width): - print("i: ", i) x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) - print("x_sliced: ", x_sliced, "ip_sliced: ", ip_sliced) # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i]) - print("logits: ", logits) else: # input_pos: [B, S] logits = model(x, input_pos) - print("logits: ", logits) # print(f"logits {logits.shape}") # print(f"x: {x},\n input_pos: {input_pos}\n") diff --git a/torchchat/model.py b/torchchat/model.py index 5af4c21e5..fd7a91359 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -310,11 +310,11 @@ def build_model(self) -> nn.Module: @abstractmethod def forward(self, *args, **kwargs): - pass + raise NotImplementedError("forward method is not implemented") @abstractmethod def setup_caches(self, *args, **kwargs): - pass + raise NotImplementedError("setup_caches method is not implemented") @classmethod def _get_model_instance(cls, config: ModelArgs): @@ -323,6 +323,10 @@ def _get_model_instance(cls, config: ModelArgs): raise ValueError("Unsupported model type:", str(config.model_type)) return model_class(config) + @classmethod + def from_model_args(cls, config: ModelArgs): + return cls._get_model_instance(config) + @classmethod def from_name(cls, name: str): return cls._get_model_instance(ModelArgs.from_name(name)) diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py index 923622edb..c7b931dae 100644 --- a/torchchat/utils/gguf_loader.py +++ b/torchchat/utils/gguf_loader.py @@ -558,7 +558,7 @@ def load_model(gguf_file: str) -> torch.nn.Module: # metadata.get(f"{arch}.rope.dimension_count", None) with torch.device("meta"): - model = Model(model_args) + model = Model.from_model_args(model_args) return model From 1599c2b6c18ddd30fab4430b9692323a54f403e0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 11 Sep 2024 01:17:48 -0700 Subject: [PATCH 43/43] solve comments --- torchchat/cli/builder.py | 1 - torchchat/model.py | 15 ++++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 6589d4c72..4f3f5727c 100644 
--- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -37,7 +37,6 @@ # bypass the import issue before torchao is ready on macos try: - from torchtune.training import set_default_dtype from torchtune.models.convert_weights import meta_to_tune except: pass diff --git a/torchchat/model.py b/torchchat/model.py index fd7a91359..f0910f54a 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -281,7 +281,7 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -class Model(nn.Module): +class Model(ABC, nn.Module): """ The entrance for model construction in torchchat. """ @@ -301,10 +301,10 @@ def build_model(self) -> nn.Module: recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - if isinstance(self.config.transformer_args[name], dict): - modules[name] = module_class(**self.config.transformer_args[name]) + if isinstance(config_args := self.config.transformer_args[name], dict): + modules[name] = module_class(**config_args) else: - modules[name] = module_class(self.config.transformer_args[name]) + modules[name] = module_class(config_args) return recipe.fusion_class(**modules) @@ -369,7 +369,12 @@ def reset_caches(self): class FlamingoModel(Model): - def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor: + def forward( + self, + tokens: Tensor, + encoder_input: Optional[Dict[str, Tensor]] = None, + encoder_mask: Optional[Tensor] = None, + ) -> Tensor: if encoder_input is None: return self.model(tokens, encoder_mask=encoder_mask) return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)
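By the end of the series, model construction is uniform across backends: a params JSON declares a model_type, ModelArgs.from_params parses it, and Model.from_model_args picks the concrete wrapper (TextOnlyModel, Llama31Model, or FlamingoModel) through MODEL_TYPE_TO_CLASS, whose build_model assembles the torchtune or native modules named by the matching ModelRecipe. A rough usage sketch under those assumptions; the params path and device are illustrative, and real runs go through torchchat.cli.builder and torchchat.generate rather than calling these pieces directly:

import torch
from torchchat.model import Model, ModelArgs

config = ModelArgs.from_params("torchchat/model_params/Meta-Llama-3.1-8B-Tune.json")

with torch.device("meta"):
    # _get_model_instance dispatches through MODEL_TYPE_TO_CLASS, so a
    # "llama3_1" params file yields a Llama31Model wrapping the torchtune
    # llama3_1 builder output under `self.model`.
    model = Model.from_model_args(config)

# ... materialize weights here, e.g. via the meta_to_tune sketch above ...

# torchtune-built wrappers (Llama31Model, FlamingoModel) size their KV caches
# from their own config and take a dtype, while the native TextOnlyModel still
# expects an explicit max_seq_length.
with torch.device("cuda"):
    model.setup_caches(max_batch_size=1, dtype=torch.bfloat16)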