-
Notifications
You must be signed in to change notification settings - Fork 250
Llama3.1 with torchtune #1123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Llama3.1 with torchtune #1123
Changes from 52 commits
0dae9ef
87397e3
0f61614
d7f3a88
994b148
d184e68
6bb2485
0d8e368
6c78850
2691bae
ba960f0
8b3a684
880dfe2
e7fa7b4
c179bcb
5ead73b
882c336
952b8bd
e764111
9679a5b
56006ea
2ec217d
a3f08ea
59337a6
33da35b
68e29bb
8ea29e7
4a6f703
5ec0811
3043433
f83154a
f891fb1
6c97eb7
11217a4
d0e2974
758af10
80b5481
2b8c939
1cc7909
95684d9
c750c08
6dc2aab
08a05b7
257b1ce
324d338
5082fb2
192841d
a5556f4
d395f7f
8130901
dc40152
6cf7db7
1599c2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||||||||||||||
from enum import Enum | ||||||||||||||||||
from pathlib import Path | ||||||||||||||||||
from typing import Callable, Dict, Optional, Union | ||||||||||||||||||
from abc import ABC, abstractmethod | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ABC is unused? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be one of the Model's parents. Fixed it. |
||||||||||||||||||
|
||||||||||||||||||
import torch | ||||||||||||||||||
import torch.nn as nn | ||||||||||||||||||
|
@@ -33,13 +34,20 @@ | |||||||||||||||||
try: | ||||||||||||||||||
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder | ||||||||||||||||||
from torchtune.modules.model_fusion import DeepFusionModel | ||||||||||||||||||
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder | ||||||||||||||||||
except: | ||||||||||||||||||
pass | ||||||||||||||||||
|
||||||||||||||||||
config_path = Path(f"{str(Path(__file__).parent)}/model_params") | ||||||||||||||||||
|
||||||||||||||||||
def identity(**kwargs): | ||||||||||||||||||
if len(kwargs) != 1: | ||||||||||||||||||
raise ValueError("Only one argument is expected") | ||||||||||||||||||
return list(kwargs.values())[0] | ||||||||||||||||||
|
||||||||||||||||||
class ModelType(Enum): | ||||||||||||||||||
TextOnly = "text_only" | ||||||||||||||||||
Llama3_1 = "llama3_1" | ||||||||||||||||||
Flamingo = "flamingo" | ||||||||||||||||||
|
||||||||||||||||||
# Type for objects that can generate nn.Module instance | ||||||||||||||||||
|
@@ -72,9 +80,18 @@ class ModelRecipe: | |||||||||||||||||
def _text_only(cls): | ||||||||||||||||||
return cls( | ||||||||||||||||||
model_type=ModelType.TextOnly, | ||||||||||||||||||
modules={'text_transformer': Transformer}, | ||||||||||||||||||
fusion_class=nn.Identity, | ||||||||||||||||||
modules={'text': Transformer}, | ||||||||||||||||||
fusion_class=identity, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def _llama3_1(cls): | ||||||||||||||||||
return cls( | ||||||||||||||||||
model_type=ModelType.Llama3_1, | ||||||||||||||||||
modules={'text': llama3_1_builder}, | ||||||||||||||||||
fusion_class=identity, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def _flamingo(cls): | ||||||||||||||||||
return cls( | ||||||||||||||||||
|
@@ -92,6 +109,8 @@ def get_recipe(cls, model_type): | |||||||||||||||||
return cls._text_only() | ||||||||||||||||||
elif model_type == ModelType.Flamingo: | ||||||||||||||||||
return cls._flamingo() | ||||||||||||||||||
elif model_type == ModelType.Llama3_1: | ||||||||||||||||||
return cls._llama3_1() | ||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError(f"Can not find the model recipe for {model_type}") | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -184,11 +203,7 @@ def from_params(cls, params_path): | |||||||||||||||||
except TypeError: | ||||||||||||||||||
# try to interpret as a dict of transformer configs | ||||||||||||||||||
model_type = ModelType(loaded_params["model_type"]) | ||||||||||||||||||
|
||||||||||||||||||
# Currently only supporting flamingo model | ||||||||||||||||||
assert model_type == ModelType.Flamingo | ||||||||||||||||||
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} | ||||||||||||||||||
|
||||||||||||||||||
return cls(transformer_args, model_type) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
|
@@ -273,11 +288,7 @@ class Model(nn.Module): | |||||||||||||||||
def __init__(self, config: ModelArgs) -> None: | ||||||||||||||||||
super().__init__() | ||||||||||||||||||
self.config = config | ||||||||||||||||||
# TODO: unify the model init logic | ||||||||||||||||||
if config.model_type == ModelType.TextOnly: | ||||||||||||||||||
self.text_transformer = Transformer(config.transformer_args["text"]) | ||||||||||||||||||
else: | ||||||||||||||||||
self.model = self.build_model() | ||||||||||||||||||
self.model = self.build_model() | ||||||||||||||||||
|
||||||||||||||||||
def build_model(self) -> nn.Module: | ||||||||||||||||||
""" | ||||||||||||||||||
|
@@ -290,50 +301,43 @@ def build_model(self) -> nn.Module: | |||||||||||||||||
recipe = ModelRecipe.get_recipe(self.config.model_type) | ||||||||||||||||||
modules = {} | ||||||||||||||||||
for name, module_class in recipe.modules.items(): | ||||||||||||||||||
modules[name] = module_class(**self.config.transformer_args[name]) | ||||||||||||||||||
|
||||||||||||||||||
if isinstance(self.config.transformer_args[name], dict): | ||||||||||||||||||
modules[name] = module_class(**self.config.transformer_args[name]) | ||||||||||||||||||
else: | ||||||||||||||||||
modules[name] = module_class(self.config.transformer_args[name]) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
|
||||||||||||||||||
return recipe.fusion_class(**modules) | ||||||||||||||||||
|
||||||||||||||||||
@abstractmethod | ||||||||||||||||||
def forward(self, *args, **kwargs): | ||||||||||||||||||
raise NotImplementedError("forward method is not implemented") | ||||||||||||||||||
|
||||||||||||||||||
def forward(self, | ||||||||||||||||||
tokens: Optional[Tensor] = None, | ||||||||||||||||||
input_pos: Optional[Tensor] = None, | ||||||||||||||||||
encoder_input: Optional[Dict[str, Tensor]] = None, | ||||||||||||||||||
encoder_mask: Optional[Tensor] = None) -> Tensor: | ||||||||||||||||||
@abstractmethod | ||||||||||||||||||
def setup_caches(self, *args, **kwargs): | ||||||||||||||||||
raise NotImplementedError("setup_caches method is not implemented") | ||||||||||||||||||
|
||||||||||||||||||
if self.config.model_type == ModelType.TextOnly: | ||||||||||||||||||
return self.text_transformer(tokens, input_pos) | ||||||||||||||||||
else: | ||||||||||||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||||||||||||
if input_pos: | ||||||||||||||||||
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.") | ||||||||||||||||||
if encoder_input is None: | ||||||||||||||||||
return self.model(tokens, encoder_mask = encoder_mask) | ||||||||||||||||||
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask) | ||||||||||||||||||
|
||||||||||||||||||
def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None): | ||||||||||||||||||
if self.config.model_type == ModelType.TextOnly: | ||||||||||||||||||
self.text_transformer.setup_caches(max_batch_size, max_seq_length) | ||||||||||||||||||
else: | ||||||||||||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||||||||||||
if max_seq_length is not None: | ||||||||||||||||||
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.") | ||||||||||||||||||
self.model.setup_caches(max_batch_size, dtype=dtype) | ||||||||||||||||||
|
||||||||||||||||||
def reset_caches(self): | ||||||||||||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||||||||||||
self.model.reset_caches() | ||||||||||||||||||
@classmethod | ||||||||||||||||||
def _get_model_instance(cls, config: ModelArgs): | ||||||||||||||||||
model_class = MODEL_TYPE_TO_CLASS.get(config.model_type) | ||||||||||||||||||
if model_class is None: | ||||||||||||||||||
raise ValueError("Unsupported model type:", str(config.model_type)) | ||||||||||||||||||
return model_class(config) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def from_model_args(cls, config: ModelArgs): | ||||||||||||||||||
return cls._get_model_instance(config) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def from_name(cls, name: str): | ||||||||||||||||||
return cls(ModelArgs.from_name(name)) | ||||||||||||||||||
return cls._get_model_instance(ModelArgs.from_name(name)) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def from_table(cls, name: str): | ||||||||||||||||||
return cls(ModelArgs.from_table(name)) | ||||||||||||||||||
return cls._get_model_instance(ModelArgs.from_table(name)) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def from_params(cls, params_path: str): | ||||||||||||||||||
return cls(ModelArgs.from_params(params_path)) | ||||||||||||||||||
return cls._get_model_instance(ModelArgs.from_params(params_path)) | ||||||||||||||||||
|
||||||||||||||||||
@classmethod | ||||||||||||||||||
def from_gguf(cls, gguf_path: str, **kwargs): | ||||||||||||||||||
|
@@ -345,6 +349,44 @@ def from_gguf(cls, gguf_path: str, **kwargs): | |||||||||||||||||
return model | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class TextOnlyModel(Model): | ||||||||||||||||||
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: | ||||||||||||||||||
return self.model(tokens, input_pos) | ||||||||||||||||||
|
||||||||||||||||||
def setup_caches(self, max_batch_size, max_seq_length): | ||||||||||||||||||
self.model.setup_caches(max_batch_size, max_seq_length) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class Llama31Model(Model): | ||||||||||||||||||
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: | ||||||||||||||||||
return self.model(tokens=tokens, input_pos=input_pos) | ||||||||||||||||||
|
||||||||||||||||||
def setup_caches(self, max_batch_size, dtype): | ||||||||||||||||||
self.model.setup_caches(max_batch_size, dtype=dtype) | ||||||||||||||||||
|
||||||||||||||||||
def reset_caches(self): | ||||||||||||||||||
self.model.reset_caches() | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class FlamingoModel(Model): | ||||||||||||||||||
def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lint long line |
||||||||||||||||||
if encoder_input is None: | ||||||||||||||||||
return self.model(tokens, encoder_mask=encoder_mask) | ||||||||||||||||||
return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask) | ||||||||||||||||||
|
||||||||||||||||||
def setup_caches(self, max_batch_size, dtype): | ||||||||||||||||||
self.model.setup_caches(max_batch_size, dtype=dtype) | ||||||||||||||||||
|
||||||||||||||||||
def reset_caches(self): | ||||||||||||||||||
self.model.reset_caches() | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
MODEL_TYPE_TO_CLASS = { | ||||||||||||||||||
ModelType.TextOnly: TextOnlyModel, | ||||||||||||||||||
ModelType.Flamingo: FlamingoModel, | ||||||||||||||||||
ModelType.Llama3_1: Llama31Model, | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
class Transformer(nn.Module): | ||||||||||||||||||
def __init__(self, config: TransformerArgs) -> None: | ||||||||||||||||||
super().__init__() | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"model_type": "llama3_1", | ||
"text": { | ||
"vocab_size": 128256, | ||
"num_layers": 80, | ||
"num_heads": 64, | ||
"num_kv_heads": 8, | ||
"embed_dim": 8192, | ||
"max_seq_len": 8192, | ||
"intermediate_dim": 28672, | ||
"attn_dropout": 0.0, | ||
"norm_eps": 1e-5, | ||
"rope_base": 500000.0 | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"model_type": "llama3_1", | ||
"text": { | ||
"vocab_size": 128256, | ||
"num_layers": 32, | ||
"num_heads": 32, | ||
"num_kv_heads": 8, | ||
"embed_dim": 4096, | ||
"max_seq_len": 8192, | ||
"intermediate_dim": 14336, | ||
"attn_dropout": 0.0, | ||
"norm_eps": 1e-5, | ||
"rope_base": 500000.0 | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused import?