Skip to content

Commit c3ec984

Browse files
authored
[Bug Fix] Fix merge parameters in pp (#8239)
* update merge pp * update description * renew
1 parent 0790824 commit c3ec984

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

llm/docs/finetune.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
202202

203203
```
204204
python merge_tp_and_pp_params.py \
205-
--model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100
205+
--model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100 \
206+
--pp 2 --tp 4
206207
```
207208

208209
<summary>&emsp; 脚本参数介绍</summary><div>

llm/merge_tp_and_pp_params.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import importlib
1515
import os
16+
import re
1617

1718
import paddle
1819

@@ -27,9 +28,33 @@ def parse_arguments():
2728
parser = argparse.ArgumentParser()
2829
parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of model.")
2930
parser.add_argument("--device", type=str, default="gpu", help="Device")
31+
parser.add_argument("--pipeline_parallel_degree", "--pp", type=int, required=True, help="pp degree")
32+
parser.add_argument("--tensor_parallel_degree", "--tp", type=int, required=True, help="tp degree")
3033
return parser.parse_args()
3134

3235

36+
def validate_model_file(path: str, tp_degree: int, pp_degree: int) -> None:
    """Verify that *path* contains every sharded checkpoint file implied by
    the given tensor-parallel and pipeline-parallel degrees.

    Expected filenames (two-digit zero-padded ranks):
      * tp only  (pp_degree == 0): ``model_state.tpXX.pdparams``
      * pp only  (tp_degree == 0): ``model_state.ppXX.pdparams``
      * tp + pp  (both > 0):       ``model_state.tpXX_ppYY.pdparams``

    Args:
        path: checkpoint directory to scan.
        tp_degree: tensor-parallel degree (0 means "no tp sharding").
        pp_degree: pipeline-parallel degree (0 means "no pp sharding").

    Raises:
        FileNotFoundError: if any expected shard file is absent; the message
            lists the missing filenames so the user can re-check --pp/--tp.
    """
    if pp_degree == 0:
        target_files = [f"model_state.tp{tp:0>2d}.pdparams" for tp in range(tp_degree)]
    elif tp_degree == 0:
        target_files = [f"model_state.pp{pp:0>2d}.pdparams" for pp in range(pp_degree)]
    else:
        target_files = [
            f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams" for tp in range(tp_degree) for pp in range(pp_degree)
        ]

    # Compare directly against the directory listing. The previous regex
    # pre-filter (r"model_state\.tp0\d*_pp0\d*\.pdparams|...") only matched
    # ranks whose first digit was 0, so an existing shard such as
    # model_state.tp10_pp00.pdparams was falsely reported missing whenever a
    # degree exceeded 10. A plain set difference is equivalent for the correct
    # cases and has no such limit.
    missing_files = set(target_files) - set(os.listdir(path))
    if missing_files:
        # sorted() makes the error message deterministic across runs.
        raise FileNotFoundError(f"Please check your pp/tp degree, missing files {sorted(missing_files)}")
56+
57+
3358
def load_tp_params(tp_degree, path):
3459
tp_state_dict_list = []
3560
for tp in range(tp_degree):
@@ -102,23 +127,30 @@ def main():
102127
paddle.set_device(args.device)
103128
config = AutoConfig.from_pretrained(args.model_name_or_path)
104129
init_class = config["architectures"][0]
105-
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
130+
if args.pipeline_parallel_degree > 1:
131+
# using pp
132+
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-15]]}.modeling_pp")
133+
else:
134+
# tp only
135+
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
106136
model_class = getattr(import_class, init_class)
107137

108-
if config.tensor_parallel_degree > 1:
109-
if config.pipeline_parallel_degree > 1:
138+
validate_model_file(args.model_name_or_path, args.tensor_parallel_degree, args.pipeline_parallel_degree)
139+
140+
if args.tensor_parallel_degree > 1:
141+
if args.pipeline_parallel_degree > 1:
110142
tp_state_dict_list = load_tp_and_pp_params(
111-
config.tensor_parallel_degree, config.pipeline_parallel_degree, args.model_name_or_path
143+
args.tensor_parallel_degree, args.pipeline_parallel_degree, args.model_name_or_path
112144
)
113145
else:
114-
tp_state_dict_list = load_tp_params(config.tensor_parallel_degree, args.model_name_or_path)
146+
tp_state_dict_list = load_tp_params(args.tensor_parallel_degree, args.model_name_or_path)
115147
state_dict_to_save = merge_tensor_parallel(
116148
model_class=model_class, state_dict_list=tp_state_dict_list, config=config
117149
)
118150
logger.info("Saving")
119151
paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
120-
elif config.pipeline_parallel_degree > 1:
121-
state_dict_to_save = load_pp_params(config.pipeline_parallel_degree, args.model_name_or_path)
152+
elif args.pipeline_parallel_degree > 1:
153+
state_dict_to_save = load_pp_params(args.pipeline_parallel_degree, args.model_name_or_path)
122154
logger.info("Saving")
123155
paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
124156
else:

0 commit comments

Comments
 (0)