 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.argparser import strtobool
-from paddlenlp.trainer.trainer_utils import ShardingOption
-from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile
+from paddlenlp.trainer.utils.helper import distributed_isfile
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _add_variant,
...
     SAFE_WEIGHTS_NAME,
 )
 from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import flatten_list, nested_copy
+from paddlenlp.utils.nested import nested_copy

 if is_safetensors_available():
     from safetensors.numpy import save_file as safe_save_file
...
 else:
     from paddlenlp.utils.safetensors import fast_load_file as load_file

+from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer
 from .shared_memory_utils import (
     _read_state_dict_from_shm,
     _traverse_copy_to_shm,
...
     get_sharded_file_name,
     get_sharded_index,
     is_need_master_weight,
+    is_sharding_split_param_mode,
     mapping_optimizer_tp_actions,
     merge_tensor_parallel_for_optimizer,
     merge_tensor_parallel_with_shard,
     reduce_master_weights_status,
     rename_shard_file,
-    save_config,
-    save_prefix_past_key_value,
+    save_model_config,
     select_model_weight_index,
     update_master_weight_status,
 )
@@ -361,25 +361,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
                 json.dump(sharded_index, f, indent=4)

         if self.args.should_save:
-            # Save prefix model past_key_values
-            if isinstance(model_to_save, PrefixModelForCausalLM):
-                save_prefix_past_key_value(model_to_save, save_directory)
-                model_to_save.prefix_config.save_pretrained(save_directory)
-            if isinstance(model_to_save, LoRAModel):
-                model_to_save.lora_config.save_pretrained(save_directory)
-
-            # save the config
-            config_to_save = save_config(model_to_save)
-            # Attach architecture to the config
-            if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
-                config_to_save.architectures = [model_to_save.model.__class__.__name__]
-            else:
-                config_to_save.architectures = [model_to_save.__class__.__name__]
-            if self.args.should_save:
-                config_to_save.save_pretrained(save_directory)
-            # save generation config
-            if model_to_save.can_generate():
-                model_to_save.generation_config.save_pretrained(save_directory)
+            save_model_config(model_to_save, save_directory)
+
         paddle.device.cuda.empty_cache()

         if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
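
Note: `save_model_config` itself is not defined in this diff, but the inline block removed above indicates roughly what it consolidates. A minimal sketch, reconstructed from that removed code and reusing the existing `save_config` / `save_prefix_past_key_value` helpers (the redundant inner `should_save` check is dropped because the caller already guards on it; the real helper may differ in detail):

def save_model_config(model_to_save, save_directory):
    # Hypothetical sketch reconstructed from the removed inline block.
    # Save prefix model past_key_values and PEFT configs first.
    if isinstance(model_to_save, PrefixModelForCausalLM):
        save_prefix_past_key_value(model_to_save, save_directory)
        model_to_save.prefix_config.save_pretrained(save_directory)
    if isinstance(model_to_save, LoRAModel):
        model_to_save.lora_config.save_pretrained(save_directory)

    # Save the model config, attaching the underlying architecture name.
    config_to_save = save_config(model_to_save)
    if isinstance(model_to_save, (LoRAModel, PrefixModelForCausalLM)):
        config_to_save.architectures = [model_to_save.model.__class__.__name__]
    else:
        config_to_save.architectures = [model_to_save.__class__.__name__]
    config_to_save.save_pretrained(save_directory)

    # Save the generation config when the model supports generation.
    if model_to_save.can_generate():
        model_to_save.generation_config.save_pretrained(save_directory)
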
@@ -391,7 +374,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
             }
             paddle.save(save_info, os.path.join(save_directory, ".saving_info"))

-    def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str):
+    def load_unified_checkpoint(self, model, resume_from_checkpoint: str):
         """Load potential model checkpoint

         Args:
@@ -539,11 +522,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
             save_single_card_optimizer(model, optimizer, output_dir)  # no need to save signal
             return

-        if (
-            self.args.sharding_parallel_degree > 1
-            and ShardingOption.SHARD_OP in self.args.sharding
-            and "split_param" in self.args.sharding_parallel_config
-        ):
+        if is_sharding_split_param_mode(self.args):
             optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
         else:
             optim_state_dict = nested_copy(optimizer.state_dict())
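
Here and in `load_unified_optimizer_locally` below, the repeated three-part condition is folded into the new `is_sharding_split_param_mode` helper imported at the top of the file. Its definition is not part of this diff; judging from the condition it replaces, it presumably amounts to something like the following sketch:

from paddlenlp.trainer.trainer_utils import ShardingOption


def is_sharding_split_param_mode(args):
    # Hypothetical sketch, inferred from the condition it replaces: sharding is enabled
    # with the SHARD_OP (stage1) option and optimizer parameters are split across the
    # sharding group via the "split_param" sharding parallel config.
    return (
        args.sharding_parallel_degree > 1
        and ShardingOption.SHARD_OP in args.sharding
        and "split_param" in args.sharding_parallel_config
    )
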
@@ -867,11 +846,7 @@ def unified_checkpoint_into_shards(

 def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
     # Special process with split param.
-    if (
-        args.sharding_parallel_degree > 1
-        and ShardingOption.SHARD_OP in args.sharding
-        and "split_param" in args.sharding_parallel_config
-    ):
+    if is_sharding_split_param_mode(args):
         returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
         return returned_optim_state_dict

@@ -1118,211 +1093,3 @@ def unified_optimizer_into_shards(
         (optim_state_dict, shard_optimizer_file, sharded_optim_index),
         (master_weights, shard_master_weight_file, sharded_master_weight_index),
     ]
-
-
-def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
-    index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)
-    index_filename = os.path.join(resume_from_checkpoint, index_filename)
-    # Find index json file and distribute this file in global group.
-    if distributed_isfile(index_filename):
-        distributed_file(index_filename)
-    else:
-        raise Exception(
-            f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine."
-        )
-
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
-    all_weight_filenames = sorted(set(index["weight_map"].values()))
-
-    # Get existed weight file list on current machine.
-    existed_filelist = []
-    existed_files = []
-    for filename in os.listdir(resume_from_checkpoint):
-        if filename in all_weight_filenames:
-            existed_files.append(filename)
-
-    # Gather all the existed files in global group.
-    dist.all_gather_object(existed_filelist, existed_files)
-    flatten_existed_filelist = flatten_list(existed_filelist)
-    diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist)))
-    if len(diff_filelist) != 0:
-        raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}")
-
-    # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines.
-    local_resume = True
-    if args.dataset_rank == 0 or args.use_expert_parallel:
-        hcg = fleet.get_hybrid_communicate_group()
-        tp_group = hcg.get_model_parallel_group()
-        pp_group = hcg.get_pipe_parallel_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-
-        need_files = set()
-        state_dict = get_expected_state_dict(model)
-        for key in state_dict.keys():
-            filename = index["weight_map"][key]
-            # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
-            if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                continue
-            need_files.add(filename)
-        diff_filelist = list(need_files.difference(set(existed_files)))
-        num_diff = paddle.to_tensor([len(diff_filelist)])
-        if tp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group)
-        if pp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
-        if args.use_expert_parallel and dp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)
-        if num_diff.item() == 0:
-            local_resume = True
-        else:
-            local_resume = False
-    local_resume = paddle.to_tensor([local_resume])
-    dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
-    local_resume = local_resume.item()
-    return local_resume
-
-
-def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
-    if not safe_serialization:
-        index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME
-    else:
-        index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
-    index_filename = os.path.join(resume_from_checkpoint, index_filename)
-    index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights)
-
-    # Find index json file and distribute the file in global group.
-    if distributed_isfile(index_filename):
-        distributed_file(index_filename)
-    else:
-        raise Exception(
-            f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine."
-        )
-
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
-    all_optimizer_filenames = sorted(set(index["weight_map"].values()))
-
-    has_master_weights = index["master_weights"]
-    # update has_master_weights and index_filename_master_weights
-    # 1. if the master weight exists, only has_master_weights is set True and loaded when needed
-    # 2. if master weight does not exist, convert model weight to master weight when needed
-    has_master_weights, index_filename_master_weights = update_master_weight_status(
-        args, optimizer, has_master_weights, safe_serialization
-    )
-    if has_master_weights:
-        index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights)
-        if distributed_isfile(index_filename_master_weights):
-            distributed_file(index_filename_master_weights)
-        else:
-            raise Exception(
-                f"Sorry, we can not find {index_filename_master_weights}. This file should be appear at least on one machine."
-            )
-        with open(index_filename_master_weights, "r") as f:
-            index_mw = json.loads(f.read())
-        all_mw_filenames = sorted(set(index_mw["weight_map"].values()))
-
-    hcg = fleet.get_hybrid_communicate_group()
-    tp_group = hcg.get_model_parallel_group()
-    pp_group = hcg.get_pipe_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
-    sharding_group = hcg.get_sharding_parallel_group()
-    sharding_rank = sharding_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-    struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
-
-    if (
-        args.sharding_parallel_degree > 1
-        and ShardingOption.SHARD_OP in args.sharding
-        and "split_param" in args.sharding_parallel_config
-    ):
-        # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume.
-        logger.warning("We only support local resume for split_param mode, do not support dynamically loading.")
-        return True
-
-    if sharding_group.nranks > 1:
-        param2rank = optimizer._param2rank
-
-    def check_complete(all_filenames):
-        # Check whether the checkpoint files on machines are complete. If not complete, raise Exception.
-        existed_filelist = []
-        existed_files = []
-        for filename in os.listdir(resume_from_checkpoint):
-            if filename in all_filenames:
-                existed_files.append(filename)
-
-        dist.all_gather_object(existed_filelist, existed_files)
-        flatten_existed_filelist = flatten_list(existed_filelist)
-        diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist)))
-        if len(diff_filelist) != 0:
-            raise Exception(
-                f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}"
-            )
-        return existed_files
-
-    def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None):
-        # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint.
-        local_resume = True
-        if args.data_parallel_rank == 0 or args.use_expert_parallel:
-            need_files = set()
-            state_dict = get_expected_state_dict(model)
-
-            for key in state_dict.keys():
-                if sharding_group.nranks > 1:
-                    static_name = struct2static_name_mappings.get(key, None)
-                    param_rank = param2rank.get(static_name, None)
-                    if param_rank != sharding_rank:
-                        continue
-
-                # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
-                if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                    continue
-
-                if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
-                    continue
-
-                if not is_master_weights:
-                    for type_name in typename_set:
-                        type_key = key + "/" + type_name
-                        filename = weight_map[type_key]
-                        need_files.add(filename)
-                else:
-                    filename = weight_map[key]
-                    need_files.add(filename)
-
-            diff_filelist = list(need_files.difference(set(existed_files)))
-            num_diff = paddle.to_tensor([len(diff_filelist)])
-            if tp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group)
-            if pp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
-            if sharding_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group)
-            if args.use_expert_parallel and dp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)
-
-            if num_diff.item() == 0:
-                local_resume = True
-            else:
-                local_resume = False
-        local_resume = paddle.to_tensor([local_resume])
-        dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
-        return local_resume.item()
-
-    # check whether the optimizer checkpoint files are complete.
-    existed_files = check_complete(all_optimizer_filenames)
-    if has_master_weights:
-        existed_files_mw = check_complete(all_mw_filenames)
-    # get optimizer's param type name, like moment1_0.
-    typename_set = set()
-    for key in index["weight_map"].keys():
-        _, typename = key.split("/")
-        typename_set.add(typename)
-    local_resume = check_dynamic_load(
-        args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set
-    )
-    local_resume_rw = True
-    if has_master_weights:
-        local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True)
-    return local_resume & local_resume_rw
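
These two functions are not deleted outright: the new `from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer` import at the top of the file indicates they were moved into a dedicated module. Assuming the moved functions keep the signatures shown above, a call site would look roughly like:

from .check_unified_checkpoint import check_unified_checkpoint

# Hypothetical call site; signature and return value follow the removed definition above.
local_resume = check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=True)
if not local_resume:
    # Checkpoint shards are incomplete locally; tensors must be sent across machines.
    ...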