
Commit da8b9ac

[LLM] Support prefix tuning and LoRA for Qwen2 (#8601)
* Add unit tests for Qwen2
* Update for tie_word_embeddings
* Update Qwen2
* Handle tokenizer attributes set to null
* Support prefix tuning training
* Fix LLM unit tests
* Fix pipeline and sequence parallelism
* Add pretrain configs
* Add LoRA, prefix tuning, and SFT configs
* Fix pipeline parallelism for LoRA and recompute
1 parent 6bfca91

22 files changed: +1007 −76 lines

llm/qwen/lora_argument_qwen2_7b.json

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+{
+    "model_name_or_path": "Qwen/Qwen2-7B",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/qwen2_7b__lora_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-04,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 2048,
+    "max_length": 4096,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "lora": true,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
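
As a quick sanity check on the LoRA config above, the sketch below (not part of the commit; the relative path assumes you run it from the repository root) loads the JSON and derives the effective per-device batch size per optimizer step.

# Hedged sketch, not from the commit: inspect the LoRA config and derive the
# effective per-device batch size per optimizer step.
import json

with open("llm/qwen/lora_argument_qwen2_7b.json") as f:
    cfg = json.load(f)

effective_batch = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
print(cfg["model_name_or_path"], "lora:", cfg["lora"])       # Qwen/Qwen2-7B lora: True
print("effective per-device batch size:", effective_batch)   # 4 * 4 = 16
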
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+{
+    "model_name_or_path": "Qwen/Qwen1.5-7B",
+    "tokenizer_name_or_path": "Qwen/Qwen1.5-7B",
+    "input_dir": "./data",
+    "output_dir": "./checkpoints/qwen1.5_7b_pretrain_ckpts",
+    "per_device_train_batch_size": 2,
+    "gradient_accumulation_steps": 8,
+    "per_device_eval_batch_size": 2,
+    "tensor_parallel_degree": 2,
+    "pipeline_parallel_degree": 1,
+    "sharding_parallel_degree": 4,
+    "sharding": "stage2",
+    "virtual_pp_degree": 1,
+    "sequence_parallel": 0,
+    "use_flash_attention": true,
+    "use_fused_rms_norm": true,
+    "use_fused_rope": true,
+    "max_seq_length": 4096,
+    "learning_rate": 3e-05,
+    "min_learning_rate": 3e-06,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "max_steps": 10000,
+    "save_steps": 5000,
+    "eval_steps": 1000,
+    "weight_decay": 0.01,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "warmup_ratio": 0.01,
+    "max_grad_norm": 1.0,
+    "dataloader_num_workers": 1,
+    "continue_training": 1,
+    "do_train": true,
+    "do_eval": true,
+    "do_predict": true,
+    "disable_tqdm": true,
+    "recompute": true,
+    "distributed_dataloader": 1,
+    "recompute_granularity": "full",
+    "save_total_limit": 2
+}
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+{
+    "model_name_or_path": "Qwen/Qwen2-7B",
+    "tokenizer_name_or_path": "Qwen/Qwen2-7B",
+    "input_dir": "./data",
+    "output_dir": "./checkpoints/qwen2_7b_pretrain_ckpts",
+    "per_device_train_batch_size": 2,
+    "gradient_accumulation_steps": 8,
+    "per_device_eval_batch_size": 2,
+    "tensor_parallel_degree": 2,
+    "pipeline_parallel_degree": 1,
+    "sharding_parallel_degree": 4,
+    "sharding": "stage2",
+    "virtual_pp_degree": 1,
+    "sequence_parallel": 0,
+    "use_flash_attention": true,
+    "use_fused_rms_norm": true,
+    "use_fused_rope": true,
+    "max_seq_length": 4096,
+    "learning_rate": 3e-05,
+    "min_learning_rate": 3e-06,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "max_steps": 10000,
+    "save_steps": 5000,
+    "eval_steps": 1000,
+    "weight_decay": 0.01,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "warmup_ratio": 0.01,
+    "max_grad_norm": 1.0,
+    "dataloader_num_workers": 1,
+    "continue_training": 1,
+    "do_train": true,
+    "do_eval": true,
+    "do_predict": true,
+    "disable_tqdm": true,
+    "recompute": false,
+    "distributed_dataloader": 1,
+    "recompute_granularity": "full",
+    "save_total_limit": 2
+}
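
Both pretrain configs above use the same parallelism layout; the arithmetic below (a sketch, not part of the commit) shows the implied GPU count and, assuming the sharding (stage2) ranks read distinct data shards as is usual, the global batch size per optimizer step.

# Hedged sketch, not from the commit: parallelism implied by the pretrain configs.
tensor_parallel_degree = 2
pipeline_parallel_degree = 1
sharding_parallel_degree = 4
world_size = tensor_parallel_degree * pipeline_parallel_degree * sharding_parallel_degree
print("GPUs required:", world_size)  # 8

per_device_train_batch_size = 2
gradient_accumulation_steps = 8
# Assumption: each sharding rank consumes a distinct data shard (data-parallel-like).
global_batch = per_device_train_batch_size * gradient_accumulation_steps * sharding_parallel_degree
print("global batch size:", global_batch)  # 2 * 8 * 4 = 64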

llm/qwen/pt_argument_qwen2_7b.json

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+{
+    "model_name_or_path": "Qwen/Qwen2-7B",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/qwen2_7b_pt_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-02,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 2048,
+    "max_length": 4096,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "prefix_tuning": true,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
+
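
All of the fine-tuning configs in this commit split the sequence budget between prompt and response. A small sketch (not part of the commit, and assuming the usual PaddleNLP meaning of src_length as the prompt truncation length and max_length as the total sequence length) of what that leaves for generated targets:

# Hedged sketch, not from the commit: token budget implied by src_length/max_length.
src_length = 2048   # maximum prompt (source) length after truncation
max_length = 4096   # maximum total sequence length
target_budget = max_length - src_length
print("tokens left for the response:", target_budget)  # 2048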

llm/qwen/sft_argument_qwen2_7b.json

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+{
+    "model_name_or_path": "Qwen/Qwen2-7B",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/qwen2-7b_sft_ckpts",
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-05,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 2048,
+    "max_length": 4096,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 4,
+    "pipeline_parallel_degree": 1,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
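
Taken together, the three Qwen2-7B fine-tuning configs differ mainly in the method flag, the learning rate, and the parallelism: LoRA and prefix tuning run with tensor_parallel_degree 1, while full-parameter SFT uses tensor_parallel_degree 4. A hedged sketch (not part of the commit; paths assume the repository root) that surfaces those differences:

# Hedged sketch, not from the commit: compare the three fine-tuning configs.
import json

configs = {
    "lora": "llm/qwen/lora_argument_qwen2_7b.json",
    "prefix_tuning": "llm/qwen/pt_argument_qwen2_7b.json",
    "sft": "llm/qwen/sft_argument_qwen2_7b.json",
}

for name, path in configs.items():
    with open(path) as f:
        cfg = json.load(f)
    print(
        name,
        "lr:", cfg["learning_rate"],              # 3e-4 / 3e-2 / 3e-5
        "tp:", cfg["tensor_parallel_degree"],     # 1 / 1 / 4
        "lora:", cfg.get("lora", False),
        "prefix:", cfg.get("prefix_tuning", False),
    )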

llm/utils.py

Lines changed: 20 additions & 1 deletion
@@ -35,6 +35,7 @@
     ChatGLMv2Tokenizer,
     LlamaForCausalLMPipe,
     PretrainedConfig,
+    Qwen2ForCausalLMPipe,
 )
 from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
 from paddlenlp.utils.log import logger
@@ -67,7 +68,7 @@ def get_prefix_tuning_params(model):
         num_hidden_layers = model.config.num_layers
         hidden_size = model.config.hidden_size
         postprocess_past_key_value = chatglm_postprocess_past_key_value
-        multi_query_group_num = model.config.multi_query_group_num
+        multi_query_group_num = model.config.multi_query_group_num  # num_key_value_heads
     elif model.base_model_prefix == "bloom":
         from paddlenlp.peft.prefix import bloom_postprocess_past_key_value
 
@@ -92,6 +93,14 @@ def get_prefix_tuning_params(model):
         hidden_size = model.config.hidden_size
         postprocess_past_key_value = qwen_postprocess_past_key_value
         multi_query_group_num = None
+    elif model.base_model_prefix == "qwen2":
+        from paddlenlp.peft.prefix import qwen_postprocess_past_key_value
+
+        num_attention_heads = model.config.num_attention_heads
+        num_hidden_layers = model.config.num_hidden_layers
+        hidden_size = model.config.hidden_size
+        postprocess_past_key_value = qwen_postprocess_past_key_value
+        multi_query_group_num = model.config.num_key_value_heads  # num_key_value_heads
     else:
         raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}. ")
     return dict(
@@ -150,6 +159,16 @@ def get_lora_target_modules(model):
             ".*mlp.w2.*",
             ".*mlp.c_proj.*",
         ]
+    elif model.base_model_prefix == "qwen2" or isinstance(model, Qwen2ForCausalLMPipe):
+        target_modules = [
+            ".*q_proj.*",
+            ".*k_proj.*",
+            ".*v_proj.*",
+            ".*o_proj.*",
+            ".*gate_proj.*",
+            ".*down_proj.*",
+            ".*up_proj.*",
+        ]
     elif model.base_model_prefix == "mixtral":
         target_modules = [
             ".*q_proj.*",

paddlenlp/transformers/model_utils.py

Lines changed: 11 additions & 0 deletions
@@ -2348,6 +2348,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             )
             pass
 
+        # Note:
+        # 1. PipelineLayer will create parameters for each layer and
+        # call `_synchronize_shared_weights()` to synchronize the shared parameters.
+        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
+        # synchronize the shared parameters.
+        # However, when state dict only contains the one piece of shared parameters, the shared parameters
+        # will be different from the original shared parameters.
+
+        if isinstance(model, PipelineLayer):
+            model._synchronize_shared_weights()
+
         if paddle.in_dynamic_mode():
             return model
 
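
The note in the diff is about tied (shared) weights in pipeline models: if a checkpoint stores only one copy of a shared parameter, a plain state-dict load can leave the tied copies out of sync, so the pipeline model re-synchronizes after loading. A framework-free toy illustration (not from the commit):

# Toy illustration, not from the commit: why shared weights need a re-sync
# after loading a checkpoint that stores only one copy of them.
import numpy as np

class TiedToyModel:
    def __init__(self):
        self.embedding = np.ones((4, 4))
        self.lm_head = self.embedding      # tied: both names refer to one array

    def load_state_dict(self, state):
        # A naive load replaces only the keys present in the checkpoint,
        # silently breaking the tie between embedding and lm_head.
        for name, value in state.items():
            setattr(self, name, value)

    def synchronize_shared_weights(self):
        # Re-establish the tie from the canonical copy.
        self.lm_head = self.embedding

model = TiedToyModel()
model.load_state_dict({"embedding": np.zeros((4, 4))})  # checkpoint has one copy
print(np.array_equal(model.embedding, model.lm_head))   # False: copies diverged
model.synchronize_shared_weights()
print(np.array_equal(model.embedding, model.lm_head))   # True after re-sync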

paddlenlp/transformers/qwen2/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,4 +15,5 @@
 
 from .configuration import *
 from .modeling import *
+from .modeling_pp import *
 from .tokenizer import *

paddlenlp/transformers/qwen2/configuration.py

Lines changed: 3 additions & 0 deletions
@@ -150,6 +150,9 @@ def __init__(
         self.eos_token_id = eos_token_id
 
         super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
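
The configuration change forwards the special-token ids to the parent constructor instead of only setting them as attributes. A toy sketch (not from the commit, and deliberately simplified relative to the real PretrainedConfig) of why forwarding matters when the base class also manages those ids:

# Toy sketch, not from the commit: forwarding special-token ids to the base
# config keeps the value the base class records consistent with the subclass.
class ToyBaseConfig:
    def __init__(self, pad_token_id=None, bos_token_id=None, eos_token_id=None, **kwargs):
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

class ToyConfigBefore(ToyBaseConfig):
    def __init__(self, eos_token_id=7, **kwargs):
        self.eos_token_id = eos_token_id
        super().__init__(**kwargs)  # ids not forwarded: base resets them to None

class ToyConfigAfter(ToyBaseConfig):
    def __init__(self, eos_token_id=7, **kwargs):
        super().__init__(eos_token_id=eos_token_id, **kwargs)  # forwarded, as in the diff

print(ToyConfigBefore().eos_token_id)  # None: the base default wins
print(ToyConfigAfter().eos_token_id)   # 7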
