1
- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
13
13
# limitations under the License.
14
14
import importlib
15
15
import os
16
+ import re
16
17
17
18
import paddle
18
19
@@ -27,9 +28,33 @@ def parse_arguments():
27
28
parser = argparse .ArgumentParser ()
28
29
parser .add_argument ("--model_name_or_path" , default = None , required = True , help = "The directory of model." )
29
30
parser .add_argument ("--device" , type = str , default = "gpu" , help = "Device" )
31
+ parser .add_argument ("--pipeline_parallel_degree" , "--pp" , type = int , required = True , help = "pp degree" )
32
+ parser .add_argument ("--tensor_parallel_degree" , "--tp" , type = int , required = True , help = "tp degree" )
30
33
return parser .parse_args ()
31
34
32
35
36
def validate_model_file(path: str, tp_degree: int, pp_degree: int) -> None:
    """Verify that ``path`` contains every sharded checkpoint file implied by
    the given tensor-/pipeline-parallel degrees.

    Expected file names follow the sharding convention visible below:

    * tp only (``pp_degree == 0``): ``model_state.tpXX.pdparams``
    * pp only (``tp_degree == 0``): ``model_state.ppXX.pdparams``
    * tp + pp (both non-zero):      ``model_state.tpXX_ppYY.pdparams``

    Args:
        path: Directory that should hold the sharded ``*.pdparams`` files.
        tp_degree: Tensor-parallel degree; 0 means no tensor parallelism.
        pp_degree: Pipeline-parallel degree; 0 means no pipeline parallelism.

    Raises:
        FileNotFoundError: If any expected shard file is missing from ``path``.
    """
    if pp_degree == 0:
        target_files = [f"model_state.tp{tp:0>2d}.pdparams" for tp in range(tp_degree)]
    elif tp_degree == 0:
        target_files = [f"model_state.pp{pp:0>2d}.pdparams" for pp in range(pp_degree)]
    else:
        target_files = [
            f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams"
            for tp in range(tp_degree)
            for pp in range(pp_degree)
        ]

    # Compare the expected names directly against the directory listing.
    # The previous regex pre-filter (`tp0\d*` / `pp0\d*`) required a literal
    # '0' after 'tp'/'pp', so shards for ranks >= 10 (e.g.
    # "model_state.tp10.pdparams") never matched and any degree > 10 raised a
    # spurious FileNotFoundError.  The filter was also redundant: the set
    # difference below is unchanged by pre-filtering the listing.
    missing_files = set(target_files) - set(os.listdir(path))
    if missing_files:
        raise FileNotFoundError(f"Please check your pp/tp degree, missing files {list(missing_files)}")
56
+
57
+
33
58
def load_tp_params (tp_degree , path ):
34
59
tp_state_dict_list = []
35
60
for tp in range (tp_degree ):
@@ -102,23 +127,30 @@ def main():
102
127
paddle .set_device (args .device )
103
128
config = AutoConfig .from_pretrained (args .model_name_or_path )
104
129
init_class = config ["architectures" ][0 ]
105
- import_class = importlib .import_module (f"paddlenlp.transformers.{ MAPPING_NAMES [init_class [:- 11 ]]} .modeling" )
130
+ if args .pipeline_parallel_degree > 1 :
131
+ # using pp
132
+ import_class = importlib .import_module (f"paddlenlp.transformers.{ MAPPING_NAMES [init_class [:- 15 ]]} .modeling_pp" )
133
+ else :
134
+ # tp only
135
+ import_class = importlib .import_module (f"paddlenlp.transformers.{ MAPPING_NAMES [init_class [:- 11 ]]} .modeling" )
106
136
model_class = getattr (import_class , init_class )
107
137
108
- if config .tensor_parallel_degree > 1 :
109
- if config .pipeline_parallel_degree > 1 :
138
+ validate_model_file (args .model_name_or_path , args .tensor_parallel_degree , args .pipeline_parallel_degree )
139
+
140
+ if args .tensor_parallel_degree > 1 :
141
+ if args .pipeline_parallel_degree > 1 :
110
142
tp_state_dict_list = load_tp_and_pp_params (
111
- config .tensor_parallel_degree , config .pipeline_parallel_degree , args .model_name_or_path
143
+ args .tensor_parallel_degree , args .pipeline_parallel_degree , args .model_name_or_path
112
144
)
113
145
else :
114
- tp_state_dict_list = load_tp_params (config .tensor_parallel_degree , args .model_name_or_path )
146
+ tp_state_dict_list = load_tp_params (args .tensor_parallel_degree , args .model_name_or_path )
115
147
state_dict_to_save = merge_tensor_parallel (
116
148
model_class = model_class , state_dict_list = tp_state_dict_list , config = config
117
149
)
118
150
logger .info ("Saving" )
119
151
paddle .save (state_dict_to_save , os .path .join (args .model_name_or_path , "model_state.pdparams" ))
120
- elif config .pipeline_parallel_degree > 1 :
121
- state_dict_to_save = load_pp_params (config .pipeline_parallel_degree , args .model_name_or_path )
152
+ elif args .pipeline_parallel_degree > 1 :
153
+ state_dict_to_save = load_pp_params (args .pipeline_parallel_degree , args .model_name_or_path )
122
154
logger .info ("Saving" )
123
155
paddle .save (state_dict_to_save , os .path .join (args .model_name_or_path , "model_state.pdparams" ))
124
156
else :
0 commit comments