Commit 0844a5b

[Feature] Optimize config saving. (#8490)
* support not saving training config.
* fix print.
* add llm meta config to run pretrain.
* fix optionals.
* delete duplicate keys.
* tmp
* change tensor_parallel_output to default to true.
* move out fuse_attention_qkv, fuse_attention_ffn
* add to finetune.
* delete print.
* tensor_parallel_output = true
* fix
* fix
* fix
* fix to dict.
* bug fix.
* add register_nonsaveable_keys.
* apply register_nonsaveable_keys.
* fix ci.
* fix qwen2
* fix all.
* fix.
* fix ci.
1 parent 7bf31a3 commit 0844a5b

37 files changed, +339 −459 lines changed

llm/argument.py

Lines changed: 3 additions & 42 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 
 from paddlenlp.trainer import TrainingArguments
 from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -105,30 +105,17 @@ class ModelArgument:
     model_name_or_path: str = field(
         default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
     )
-    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
     tokenizer_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
-    use_fused_rms_norm: bool = field(
-        default=False,
-        metadata={"help": "llama or other model, use_fused_rms_norm"},
-    )
     fuse_attention_qkv: bool = field(
-        default=False,
+        default=None,
         metadata={"help": "whether to fuse attention qkv"},
     )
     fuse_attention_ffn: bool = field(
-        default=False,
+        default=None,
         metadata={"help": "whether to fuse first up and gate proj in mlp block"},
     )
-    recompute_granularity: str = field(
-        default="full",
-        metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"},
-    )
-    virtual_pp_degree: int = field(
-        default=1,
-        metadata={"help": "virtual_pp_degree"},
-    )
     hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
     attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})
 
@@ -138,32 +125,6 @@ class ModelArgument:
             "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
         },
     )
-    sequence_parallel: bool = field(
-        default=False,
-        metadata={"help": "whether to use sequence parallel"},
-    )
-    fuse_sequence_parallel_allreduce: bool = field(
-        default=False,
-        metadata={"help": "whether to use fuse sequence parallel allreduce"},
-    )
-    use_fused_rope: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Enable rope fusion or not."},
-    )
-    no_recompute_layers: Optional[List[int]] = field(
-        default=None,
-        metadata={"help": "Specify the full transformer layers that should not be recomputed."},
-    )
-    pp_recompute_interval: int = field(
-        default=1,
-        metadata={
-            "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0."
-        },
-    )
-    recompute_use_reentrant: bool = field(
-        default=False,
-        metadata={"help": "recompute_use_reentrant"},
-    )
     weight_quantize_algo: str = field(
         default=None,
         metadata={

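Note on the changed defaults above: fuse_attention_qkv and fuse_attention_ffn now default to None instead of False, so a value stored in the pretrained config is kept unless the flag is passed explicitly (see the "is not None" checks in llm/finetune_generation.py below). A minimal sketch of that pattern, using illustrative names (ToyArgs, apply_overrides) that are not part of PaddleNLP:

# Sketch only: ToyArgs and apply_overrides are made-up names illustrating the
# "None means keep the checkpoint's value" pattern used in this commit.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyArgs:
    # Defaulting to None (rather than False) distinguishes "not set by the
    # user" from an explicit False given on the command line.
    fuse_attention_qkv: Optional[bool] = field(default=None)
    fuse_attention_ffn: Optional[bool] = field(default=None)


def apply_overrides(config, args: ToyArgs):
    # Only flags the user actually set override what the pretrained config
    # already carries; unset flags leave the saved config untouched.
    if args.fuse_attention_qkv is not None:
        config.fuse_attention_qkv = args.fuse_attention_qkv
    if args.fuse_attention_ffn is not None:
        config.fuse_attention_ffn = args.fuse_attention_ffn
    return config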
llm/finetune_generation.py

Lines changed: 28 additions & 45 deletions
@@ -49,10 +49,12 @@
 from paddlenlp.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
     AutoTokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
 )
+from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
 from paddlenlp.utils.log import logger
 
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
@@ -68,6 +70,7 @@ def docstring_decorator(fn):
 
 
 @dataclass
+@llmmetaclass
 @add_start_docstrings(TrainingArguments.__doc__)
 class FinetuneArguments(TrainingArguments):
     decay_steps: int = field(
@@ -146,65 +149,45 @@ def main():
 
     model_config = AutoConfig.from_pretrained(
         model_args.model_name_or_path,
-        tensor_parallel_output=training_args.tensor_parallel_output,
-        tensor_parallel_degree=training_args.tensor_parallel_degree,
-        tensor_parallel_rank=training_args.tensor_parallel_rank,
         dtype=dtype,
         from_aistudio=model_args.from_aistudio,
         quantization_config=quantization_config,
     )
-    if hasattr(model_config, "use_flash_attention"):
-        model_config.use_flash_attention = model_args.use_flash_attention
-
-    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
-    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
-    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
-    model_config.recompute_granularity = model_args.recompute_granularity
-    model_config.virtual_pp_degree = model_args.virtual_pp_degree
-    model_config.sequence_parallel = model_args.sequence_parallel
-    model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
-    model_config.use_fused_rope = model_args.use_fused_rope
-
-    model_config.no_recompute_layers = model_args.no_recompute_layers
-    model_config.pp_recompute_interval = model_args.pp_recompute_interval
-    model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
-    model_config.use_recompute = training_args.recompute
-
-    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
-    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
+
+    LlmMetaConfig.set_llm_config(model_config, training_args)
 
     # Config for model using dropout, such as GPT.
-    model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
-    model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+    if hasattr(model_config, "hidden_dropout_prob"):
+        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    if hasattr(model_config, "attention_probs_dropout_prob"):
+        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+
+    if model_args.fuse_attention_qkv is not None:
+        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    if model_args.fuse_attention_ffn is not None:
+        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
 
-    model_config.sep_parallel_degree = training_args.sep_parallel_degree
-    model_config.tensor_parallel_output = training_args.tensor_parallel_output
     model_config.seq_length = data_args.max_length
 
+    print("Final model config:", model_config)
+
+    model_class = AutoModelForCausalLM
     if training_args.pipeline_parallel_degree > 1:
         if data_args.eval_with_do_generation and training_args.do_eval:
             raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
-        from paddlenlp.transformers import AutoModelForCausalLMPipe
 
-        if not training_args.autotuner_benchmark:
-            model = AutoModelForCausalLMPipe.from_pretrained(
-                model_args.model_name_or_path,
-                config=model_config,
-                from_aistudio=model_args.from_aistudio,
-            )
-        else:
-            # NOTE(gongenlei): new add autotuner_benchmark
-            model = AutoModelForCausalLMPipe.from_config(model_config, dtype=dtype)
+        model_class = AutoModelForCausalLMPipe
+
+    if not training_args.autotuner_benchmark:
+        model = model_class.from_pretrained(
+            model_args.model_name_or_path,
+            config=model_config,
+            from_aistudio=model_args.from_aistudio,
+        )
     else:
-        if not training_args.autotuner_benchmark:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_args.model_name_or_path,
-                config=model_config,
-                from_aistudio=model_args.from_aistudio,
-            )
-        else:
-            # NOTE(gongenlei): new add autotuner_benchmark
-            model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)
+        # NOTE(gongenlei): new add autotuner_benchmark
+        model = model_class.from_config(model_config, dtype=dtype)
+
     if training_args.do_train and model_args.neftune:
         # Inspired by https://github.com/neelsjain/NEFTune
         if hasattr(model, "get_input_embeddings"):

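Taken together, the finetune changes above replace the per-attribute copying and the duplicated pipeline/non-pipeline load paths with a single flow. A rough sketch of the resulting control flow, assuming parsed model_args, data_args and training_args plus a dtype string are already available as in the original script (the helper name build_model is illustrative, not part of the diff):

# Sketch of the simplified model construction flow in finetune_generation.py.
# build_model is a made-up helper; the argument objects are assumed to come
# from PdArgumentParser as in the original script.
from paddlenlp.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForCausalLMPipe,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig


def build_model(model_args, data_args, training_args, dtype):
    model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype)

    # One call copies the shared LLM meta settings (parallelism, recompute,
    # etc.) from the training arguments onto the model config, replacing the
    # long list of manual model_config.x = training_args.x assignments removed above.
    LlmMetaConfig.set_llm_config(model_config, training_args)

    # Model-specific overrides only apply when the user set them explicitly.
    if model_args.fuse_attention_qkv is not None:
        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
    if model_args.fuse_attention_ffn is not None:
        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
    model_config.seq_length = data_args.max_length

    # Choose the class once, instead of duplicating the load logic per branch.
    model_class = AutoModelForCausalLM
    if training_args.pipeline_parallel_degree > 1:
        model_class = AutoModelForCausalLMPipe

    if not training_args.autotuner_benchmark:
        return model_class.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
        )
    return model_class.from_config(model_config, dtype=dtype)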
llm/llama/tests/unified-ckpt-llama-500m/config.json

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
   "num_attention_heads": 8,
   "num_hidden_layers": 8,
   "pad_token_id": 0,
-  "paddlenlp_version": null,
   "rms_norm_eps": 1e-06,
   "vocab_size": 32000
 }

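The dropped "paddlenlp_version": null entry in this test fixture reflects the register_nonsaveable_keys change mentioned in the commit message: runtime-only keys are no longer written into saved config.json files. The following is an illustrative toy sketch of that idea only, not PaddleNLP's actual implementation (ToyConfig is a made-up class):

# Toy sketch only: NOT PaddleNLP's implementation. It just shows the idea of
# registering keys that should never be serialized when a config is saved.
import json


class ToyConfig:
    _nonsaveable_keys = set()

    @classmethod
    def register_nonsaveable_keys(cls, keys):
        # Keys registered here are treated as runtime-only settings.
        cls._nonsaveable_keys.update(keys)

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def save_pretrained(self, path):
        # Drop the registered keys before writing config.json.
        data = {k: v for k, v in self.__dict__.items() if k not in self._nonsaveable_keys}
        with open(path, "w") as f:
            json.dump(data, f, indent=2, sort_keys=True)


ToyConfig.register_nonsaveable_keys(["paddlenlp_version", "tensor_parallel_rank"])
cfg = ToyConfig(num_hidden_layers=8, paddlenlp_version=None, tensor_parallel_rank=0)
cfg.save_pretrained("config.json")  # writes only {"num_hidden_layers": 8}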