Llama3.1 with torchtune #1123


Merged: 53 commits, merged on Sep 11, 2024 (changes shown from 52 commits)

Commits
0dae9ef
added model source and type for torchtune flamingo support
Gasoonjia Aug 27, 2024
87397e3
added model source and type for torchtune flamingo support
Gasoonjia Aug 27, 2024
0f61614
grab missing enum
Gasoonjia Aug 27, 2024
d7f3a88
fix ModelArgs init
Gasoonjia Aug 27, 2024
994b148
create init func for ModelArgs for BC
Gasoonjia Aug 28, 2024
d184e68
update pipeline for ModleSource and ModelType
Gasoonjia Aug 28, 2024
6bb2485
Merge branch 'main' of github.com:pytorch/torchchat into flamingo_com…
Gasoonjia Aug 28, 2024
0d8e368
revert lintrunner update on ET
Gasoonjia Aug 28, 2024
6c78850
introduce flamingo modules form torchtune
Gasoonjia Aug 28, 2024
2691bae
back up to move to linux
Gasoonjia Aug 28, 2024
ba960f0
mitigate building issue
Gasoonjia Aug 29, 2024
8b3a684
pass local test
Gasoonjia Aug 30, 2024
880dfe2
merge solved
Gasoonjia Aug 30, 2024
e7fa7b4
structual model builder
Gasoonjia Sep 3, 2024
c179bcb
update torchtune address
Gasoonjia Sep 5, 2024
5ead73b
update install requirement
Gasoonjia Sep 6, 2024
882c336
support new torchtune flamingo component
Gasoonjia Sep 6, 2024
952b8bd
specific version for vision and ao
Gasoonjia Sep 6, 2024
e764111
unify text-only model generation pipeline
Gasoonjia Sep 9, 2024
9679a5b
convert installation back and bypass torchtune
Gasoonjia Sep 9, 2024
56006ea
Merge branch 'main' into flamingo_component
Gasoonjia Sep 9, 2024
2ec217d
Merge branch 'flamingo_component' into llama3.1_with_torchtune
Gasoonjia Sep 9, 2024
a3f08ea
restructual model definition
Gasoonjia Sep 9, 2024
59337a6
update exportation variable name
Gasoonjia Sep 9, 2024
33da35b
Merge branch 'flamingo_component' into llama3.1_with_torchtune
Gasoonjia Sep 9, 2024
68e29bb
remove redunctant function
Gasoonjia Sep 10, 2024
8ea29e7
1/n torchtune 3.1 8b
Gasoonjia Sep 10, 2024
4a6f703
installation update
Gasoonjia Sep 10, 2024
5ec0811
torchtune 3.1 8b / 30b
Gasoonjia Sep 10, 2024
3043433
merge main
Gasoonjia Sep 10, 2024
f83154a
bring torchchat llama3.1 back
Gasoonjia Sep 10, 2024
f891fb1
bring tok vali back to torchchat model + revert install_requirements.sh
Gasoonjia Sep 10, 2024
6c97eb7
solve bugs related to tt model support
Gasoonjia Sep 10, 2024
11217a4
bypass torchtune import issue
Gasoonjia Sep 10, 2024
d0e2974
solve merge confilct
Gasoonjia Sep 10, 2024
758af10
solve Jack's wonderful comments
Gasoonjia Sep 11, 2024
80b5481
remveo extra dot
Gasoonjia Sep 11, 2024
2b8c939
merge into flamingo_component
Gasoonjia Sep 11, 2024
1cc7909
add type.Callable
Gasoonjia Sep 11, 2024
95684d9
fix torchchat typos
Gasoonjia Sep 11, 2024
c750c08
merge with flamingo_component
Gasoonjia Sep 11, 2024
6dc2aab
solve bug when args.model is None
Gasoonjia Sep 11, 2024
08a05b7
support builder_args.params_table is None
Gasoonjia Sep 11, 2024
257b1ce
remove all .DS_Store
Gasoonjia Sep 11, 2024
324d338
bring gguf back
Gasoonjia Sep 11, 2024
5082fb2
merge main
Gasoonjia Sep 11, 2024
192841d
remove reduntant updates
Gasoonjia Sep 11, 2024
a5556f4
bring checkpoint back
Gasoonjia Sep 11, 2024
d395f7f
debug
Gasoonjia Sep 11, 2024
8130901
debug
Gasoonjia Sep 11, 2024
dc40152
debug
Gasoonjia Sep 11, 2024
6cf7db7
new factory func to produce Model from modelargs
Gasoonjia Sep 11, 2024
1599c2b
solve comments
Gasoonjia Sep 11, 2024
26 changes: 20 additions & 6 deletions torchchat/cli/builder.py
Expand Up @@ -35,6 +35,14 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

# bypass the import issue before torchao is ready on macos
try:
from torchtune.training import set_default_dtype
Contributor: Unused import?

from torchtune.models.convert_weights import meta_to_tune
except:
pass



@dataclass
class BuilderArgs:
Expand Down Expand Up @@ -328,11 +336,15 @@ def _load_model_default(builder_args, only_config=False):
assert not builder_args.gguf_path

model = _init_model_on_meta_device(builder_args)
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
cps = []
if builder_args.checkpoint_dir is not None:

if builder_args.params_table and builder_args.params_table.endswith("Tune"):
print("Loading Tune checkpoint")
meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
checkpoint = meta_to_tune(meta_checkpoint)
elif builder_args.checkpoint_dir is not None:
# Load multiple checkpoint; ignore the single path.
builder_args.checkpoint_path = None
cps = []
for i in range(4):
cp_name = f"consolidated.{i}.pth"
print(f"Loading {cp_name}")
Expand Down Expand Up @@ -363,10 +375,10 @@ def _load_model_default(builder_args, only_config=False):

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]

checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}

checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)

return model


Expand Down Expand Up @@ -534,7 +546,9 @@ def _initialize_model(
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length
max_batch_size=1,
max_seq_length=max_seq_length
or model.config.transformer_args["text"].max_seq_length,
)

model.to(dtype=builder_args.precision)
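For reference, a minimal sketch of the torchtune checkpoint-loading branch added above, assuming torchtune's meta_to_tune converter is importable; load_tune_or_default_checkpoint is a hypothetical helper name used only for illustration, not part of the PR:

import torch

try:
    from torchtune.models.convert_weights import meta_to_tune
except ImportError:
    meta_to_tune = None

def load_tune_or_default_checkpoint(checkpoint_path, params_table):
    # A "Tune" suffix on the params table marks a torchtune-flavored model (see models.json below).
    if params_table and params_table.endswith("Tune"):
        if meta_to_tune is None:
            raise RuntimeError("torchtune is required to load Tune checkpoints")
        # Load the Meta-format weights lazily, then remap the keys to torchtune's layout.
        meta_checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        return meta_to_tune(meta_checkpoint)
    # Fall back to the existing single-file path for native torchchat models.
    return torch.load(str(checkpoint_path), mmap=True, weights_only=True)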
23 changes: 19 additions & 4 deletions torchchat/generate.py
Expand Up @@ -116,6 +116,7 @@ class GeneratorArgs:
speculate_k: int = 5
sequential_prefill: bool = False
max_autotune: bool = False
is_torchtune_model: bool = False

def __post_init__(self):
if self.compile_prefill and self.sequential_prefill:
Expand Down Expand Up @@ -161,6 +162,7 @@ def from_args(cls, args):
speculate_k=args.speculate_k,
sequential_prefill=sequential_prefill,
max_autotune=args.max_autotune,
is_torchtune_model=args.model and args.model.endswith("tune"),
)


Expand Down Expand Up @@ -197,6 +199,8 @@ def __init__(
self.profile = profile
self.quantize = quantize
self.draft_quantize = draft_quantize
self.is_torchtune_model = generator_args.is_torchtune_model
self.dtype = builder_args.precision

# global print
# from tp import maybe_init_dist
Expand Down Expand Up @@ -263,7 +267,10 @@ def __init__(
else:
self.draft_model = None

self.tokenizer_args.validate_model(self.model)
# torchtune model does not contain essential info for validation
# TODO: refactor model config to be more generic
if not self.is_torchtune_model:
self.tokenizer_args.validate_model(self.model)
self.tokenizer_args.validate_model(self.draft_model, "draft model")
generator_args.validate_build(self.builder_args)
generator_args.validate_build(self.speculative_builder_args, "draft model")
Expand Down Expand Up @@ -295,7 +302,7 @@ def sample(
need_probs: bool,
temperature: float = 1.0,
top_k: Optional[int] = None,
):
):
if temperature == 0 and not need_probs:
_, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
return (idx_next, None)
Expand Down Expand Up @@ -517,7 +524,10 @@ def generate(
if start_pos == 0:
model = model.to(device=device)
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if self.is_torchtune_model:
model.setup_caches(max_batch_size=1, dtype=self.dtype)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length
Expand Down Expand Up @@ -686,7 +696,12 @@ def chat(

self.system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:

# This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong
# TODO: unify the max_seq_length config representation.
if generator_args.is_torchtune_model:
max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"]
elif generator_args.chat_mode:
max_seq_length = self.model.config.transformer_args["text"].max_seq_length
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
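The cache-setup split above can be condensed into a short sketch (hedged: setup_model_caches is a hypothetical helper; in the PR the logic lives inline in generate()). torchtune decoders size their KV caches from the model config and only need a dtype, while the native torchchat Transformer takes an explicit max_seq_length.

import torch

def setup_model_caches(model, is_torchtune_model, max_seq_length, dtype, device):
    # Condensed from the branch inside generate() above.
    with torch.device(device):
        if is_torchtune_model:
            # torchtune models derive cache length from their config and need a dtype.
            model.setup_caches(max_batch_size=1, dtype=dtype)
        else:
            # The native torchchat Transformer takes an explicit max_seq_length instead.
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)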
128 changes: 85 additions & 43 deletions torchchat/model.py
Expand Up @@ -11,6 +11,7 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from abc import ABC, abstractmethod
Contributor: ABC is unused?

Author: should be one of the Model's parents. Fixed it.


import torch
import torch.nn as nn
Expand All @@ -33,13 +34,20 @@
try:
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
except:
pass

config_path = Path(f"{str(Path(__file__).parent)}/model_params")

def identity(**kwargs):
if len(kwargs) != 1:
raise ValueError("Only one argument is expected")
return list(kwargs.values())[0]

class ModelType(Enum):
TextOnly = "text_only"
Llama3_1 = "llama3_1"
Flamingo = "flamingo"

# Type for objects that can generate nn.Module instance
Expand Down Expand Up @@ -72,9 +80,18 @@ class ModelRecipe:
def _text_only(cls):
return cls(
model_type=ModelType.TextOnly,
modules={'text_transformer': Transformer},
fusion_class=nn.Identity,
modules={'text': Transformer},
fusion_class=identity,
)

@classmethod
def _llama3_1(cls):
return cls(
model_type=ModelType.Llama3_1,
modules={'text': llama3_1_builder},
fusion_class=identity,
)

@classmethod
def _flamingo(cls):
return cls(
Expand All @@ -92,6 +109,8 @@ def get_recipe(cls, model_type):
return cls._text_only()
elif model_type == ModelType.Flamingo:
return cls._flamingo()
elif model_type == ModelType.Llama3_1:
return cls._llama3_1()
else:
raise ValueError(f"Can not find the model recipe for {model_type}")

Expand Down Expand Up @@ -184,11 +203,7 @@ def from_params(cls, params_path):
except TypeError:
# try to interpret as a dict of transformer configs
model_type = ModelType(loaded_params["model_type"])

# Currently only supporting flamingo model
assert model_type == ModelType.Flamingo
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}

return cls(transformer_args, model_type)

@classmethod
Expand Down Expand Up @@ -273,11 +288,7 @@ class Model(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
# TODO: unify the model init logic
if config.model_type == ModelType.TextOnly:
self.text_transformer = Transformer(config.transformer_args["text"])
else:
self.model = self.build_model()
self.model = self.build_model()

def build_model(self) -> nn.Module:
"""
Expand All @@ -290,50 +301,43 @@ def build_model(self) -> nn.Module:
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
modules[name] = module_class(**self.config.transformer_args[name])

if isinstance(self.config.transformer_args[name], dict):
modules[name] = module_class(**self.config.transformer_args[name])
else:
modules[name] = module_class(self.config.transformer_args[name])
Contributor (suggested change):
-    if isinstance(self.config.transformer_args[name], dict):
-        modules[name] = module_class(**self.config.transformer_args[name])
-    else:
-        modules[name] = module_class(self.config.transformer_args[name])
+    if isinstance(config_args := self.config.transformer_args[name], dict):
+        modules[name] = module_class(**config_args)
+    else:
+        modules[name] = module_class(config_args)


return recipe.fusion_class(**modules)

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")

def forward(self,
tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None) -> Tensor:
@abstractmethod
def setup_caches(self, *args, **kwargs):
raise NotImplementedError("setup_caches method is not implemented")

if self.config.model_type == ModelType.TextOnly:
return self.text_transformer(tokens, input_pos)
else:
assert self.config.model_type == ModelType.Flamingo
if input_pos:
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.")
if encoder_input is None:
return self.model(tokens, encoder_mask = encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask)

def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None):
if self.config.model_type == ModelType.TextOnly:
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
else:
assert self.config.model_type == ModelType.Flamingo
if max_seq_length is not None:
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.")
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
assert self.config.model_type == ModelType.Flamingo
self.model.reset_caches()
@classmethod
def _get_model_instance(cls, config: ModelArgs):
model_class = MODEL_TYPE_TO_CLASS.get(config.model_type)
if model_class is None:
raise ValueError("Unsupported model type:", str(config.model_type))
return model_class(config)

@classmethod
def from_model_args(cls, config: ModelArgs):
return cls._get_model_instance(config)

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
return cls._get_model_instance(ModelArgs.from_name(name))

@classmethod
def from_table(cls, name: str):
return cls(ModelArgs.from_table(name))
return cls._get_model_instance(ModelArgs.from_table(name))

@classmethod
def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))
return cls._get_model_instance(ModelArgs.from_params(params_path))

@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
Expand All @@ -345,6 +349,44 @@ def from_gguf(cls, gguf_path: str, **kwargs):
return model


class TextOnlyModel(Model):
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.model(tokens, input_pos)

def setup_caches(self, max_batch_size, max_seq_length):
self.model.setup_caches(max_batch_size, max_seq_length)


class Llama31Model(Model):
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.model(tokens=tokens, input_pos=input_pos)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
self.model.reset_caches()


class FlamingoModel(Model):
def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor:
Contributor: lint long line

if encoder_input is None:
return self.model(tokens, encoder_mask=encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
self.model.reset_caches()


MODEL_TYPE_TO_CLASS = {
ModelType.TextOnly: TextOnlyModel,
ModelType.Flamingo: FlamingoModel,
ModelType.Llama3_1: Llama31Model,
}

class Transformer(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
super().__init__()
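Taken together, the model.py refactor routes every constructor through _get_model_instance, so the concrete subclass is chosen by ModelType. A hedged usage sketch (the params path mirrors the JSON files added later in this PR):

from torchchat.model import Model, ModelType

# from_params parses the JSON, _get_model_instance maps ModelType.Llama3_1 to
# Llama31Model via MODEL_TYPE_TO_CLASS, and build_model then calls the torchtune
# llama3_1 builder with the "text" kwargs.
model = Model.from_params("torchchat/model_params/Meta-Llama-3.1-8B-Tune.json")
assert model.config.model_type == ModelType.Llama3_1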
12 changes: 12 additions & 0 deletions torchchat/model_config/models.json
Expand Up @@ -57,6 +57,18 @@
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B"
},
"meta-llama/Meta-Llama-3.1-8B-Instruct-Tune": {
"aliases": ["llama3.1-tune", "llama3.1-chat-tune", "llama3.1-instruct-tune"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-8B-Tune"
},
"meta-llama/Meta-Llama-3.1-70B-Instruct-Tune": {
"aliases": ["llama3.1-70b-tune"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B-Tune"
},
"meta-llama/CodeLlama-7b-Python-hf": {
"aliases": ["codellama", "codellama-7b"],
"distribution_channel": "HuggingFaceSnapshot",
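The new models.json entries above only add aliases and a Tune-suffixed transformer_params_key. A hedged sketch, assuming the repo-relative file path, of how that suffix feeds the loader:

import json
from pathlib import Path

models = json.loads(Path("torchchat/model_config/models.json").read_text())
entry = models["meta-llama/Meta-Llama-3.1-8B-Instruct-Tune"]
assert "llama3.1-tune" in entry["aliases"]
# The "Tune" suffix on transformer_params_key is what the params_table check in
# builder.py keys on to pick the torchtune checkpoint-conversion path.
assert entry["transformer_params_key"].endswith("Tune")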
15 changes: 15 additions & 0 deletions torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
@@ -0,0 +1,15 @@
{
"model_type": "llama3_1",
"text": {
"vocab_size": 128256,
"num_layers": 80,
"num_heads": 64,
"num_kv_heads": 8,
"embed_dim": 8192,
"max_seq_len": 8192,
"intermediate_dim": 28672,
"attn_dropout": 0.0,
"norm_eps": 1e-5,
"rope_base": 500000.0
}
}
15 changes: 15 additions & 0 deletions torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
@@ -0,0 +1,15 @@
{
"model_type": "llama3_1",
"text": {
"vocab_size": 128256,
"num_layers": 32,
"num_heads": 32,
"num_kv_heads": 8,
"embed_dim": 4096,
"max_seq_len": 8192,
"intermediate_dim": 14336,
"attn_dropout": 0.0,
"norm_eps": 1e-5,
"rope_base": 500000.0
}
}
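These params files keep the torchtune builder's keyword names (num_layers, embed_dim, max_seq_len, and so on) rather than torchchat's TransformerArgs fields. A hedged sketch of how ModelArgs.from_params, shown in the model.py diff above, consumes them:

import json

with open("torchchat/model_params/Meta-Llama-3.1-8B-Tune.json") as f:
    loaded_params = json.load(f)

# "model_type" selects ModelType.Llama3_1; the remaining keys become transformer_args,
# so transformer_args["text"] stays a plain dict of torchtune builder kwargs.
model_type = loaded_params["model_type"]  # "llama3_1"
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}
# Because the value is a dict, Model.build_model calls llama3_1(**transformer_args["text"]).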