Skip to content

[LLM] Support prefix tuning and lora for qwen2 #8601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions llm/qwen/lora_argument_qwen2_7b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2_7b__lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps": 16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
41 changes: 41 additions & 0 deletions llm/qwen/pretrain-qwen1.5_7b-tp2sd4_stage2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"model_name_or_path": "Qwen/Qwen1.5-7B",
"tokenizer_name_or_path": "Qwen/Qwen1.5-7B",
"input_dir": "./data",
"output_dir": "./checkpoints/qwen1.5_7b_pretrain_ckpts",
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 2,
"tensor_parallel_degree": 2,
"pipeline_parallel_degree": 1,
"sharding_parallel_degree": 4,
"sharding": "stage2",
"virtual_pp_degree": 1,
"sequence_parallel": 0,
"use_flash_attention": true,
"use_fused_rms_norm": true,
"use_fused_rope": true,
"max_seq_length": 4096,
"learning_rate": 3e-05,
"min_learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"max_steps": 10000,
"save_steps": 5000,
"eval_steps": 1000,
"weight_decay": 0.01,
"bf16": true,
"fp16_opt_level": "O2",
"warmup_ratio": 0.01,
"max_grad_norm": 1.0,
"dataloader_num_workers": 1,
"continue_training": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"recompute": true,
"distributed_dataloader": 1,
"recompute_granularity": "full",
"save_total_limit": 2
}
41 changes: 41 additions & 0 deletions llm/qwen/pretrain-qwen2_7b-tp2sd4_stage2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"tokenizer_name_or_path": "Qwen/Qwen2-7B",
"input_dir": "./data",
"output_dir": "./checkpoints/qwen2_7b_pretrain_ckpts",
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 2,
"tensor_parallel_degree": 2,
"pipeline_parallel_degree": 1,
"sharding_parallel_degree": 4,
"sharding": "stage2",
"virtual_pp_degree": 1,
"sequence_parallel": 0,
"use_flash_attention": true,
"use_fused_rms_norm": true,
"use_fused_rope": true,
"max_seq_length": 4096,
"learning_rate": 3e-05,
"min_learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"max_steps": 10000,
"save_steps": 5000,
"eval_steps": 1000,
"weight_decay": 0.01,
"bf16": true,
"fp16_opt_level": "O2",
"warmup_ratio": 0.01,
"max_grad_norm": 1.0,
"dataloader_num_workers": 1,
"continue_training": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"recompute": false,
"distributed_dataloader": 1,
"recompute_granularity": "full",
"save_total_limit": 2
}
33 changes: 33 additions & 0 deletions llm/qwen/pt_argument_qwen2_7b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2_7b_pt_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps": 16,
"num_train_epochs": 3,
"learning_rate": 3e-02,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"prefix_tuning": true,
"zero_padding": false,
"use_flash_attention": false
}

31 changes: 31 additions & 0 deletions llm/qwen/sft_argument_qwen2_7b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2-7b_sft_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 4,
"pipeline_parallel_degree": 1,
"zero_padding": false,
"use_flash_attention": false
}
21 changes: 20 additions & 1 deletion llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ChatGLMv2Tokenizer,
LlamaForCausalLMPipe,
PretrainedConfig,
Qwen2ForCausalLMPipe,
)
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
from paddlenlp.utils.log import logger
Expand Down Expand Up @@ -67,7 +68,7 @@ def get_prefix_tuning_params(model):
num_hidden_layers = model.config.num_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = chatglm_postprocess_past_key_value
multi_query_group_num = model.config.multi_query_group_num
multi_query_group_num = model.config.multi_query_group_num # num_key_value_heads
elif model.base_model_prefix == "bloom":
from paddlenlp.peft.prefix import bloom_postprocess_past_key_value

Expand All @@ -92,6 +93,14 @@ def get_prefix_tuning_params(model):
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = None
elif model.base_model_prefix == "qwen2":
from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

num_attention_heads = model.config.num_attention_heads
num_hidden_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = model.config.num_key_value_heads # num_key_value_heads
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}. ")
return dict(
Expand Down Expand Up @@ -150,6 +159,16 @@ def get_lora_target_modules(model):
".*mlp.w2.*",
".*mlp.c_proj.*",
]
elif model.base_model_prefix == "qwen2" or isinstance(model, Qwen2ForCausalLMPipe):
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate_proj.*",
".*down_proj.*",
".*up_proj.*",
]
elif model.base_model_prefix == "mixtral":
target_modules = [
".*q_proj.*",
Expand Down
11 changes: 11 additions & 0 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,17 @@
)
pass

# Note:
# 1. PipelineLayer will create parameters for each layer and
# call `_synchronize_shared_weights()` to synchronize the shared parameters.
# 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
# synchronize the shared parameters.
# However, when state dict only contains the one piece of shared parameters, the shared parameters
# will be different from the original shared parameters.

if isinstance(model, PipelineLayer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

赞👍🏻

model._synchronize_shared_weights()

Check warning on line 2368 in paddlenlp/transformers/model_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/model_utils.py#L2368

Added line #L2368 was not covered by tests

if paddle.in_dynamic_mode():
return model

Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@

from .configuration import *
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
3 changes: 3 additions & 0 deletions paddlenlp/transformers/qwen2/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def __init__(
self.eos_token_id = eos_token_id

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
Loading
Loading