
Commit dc4c75a

mkdir unified_checkpoint directory
1 parent cbbc074 commit dc4c75a

7 files changed: +58 additions, -282 deletions

paddlenlp/trainer/trainer.py

Lines changed: 0 additions & 3 deletions
@@ -598,7 +598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         if use_unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 resume_from_checkpoint,
             )
             logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")

@@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 self.state.best_model_checkpoint,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:

@@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self):
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 self.state.best_model_checkpoint,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:

paddlenlp/trainer/plugins/unified_checkpoint.py renamed to paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 10 additions & 243 deletions
@@ -33,8 +33,7 @@
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.argparser import strtobool
-from paddlenlp.trainer.trainer_utils import ShardingOption
-from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile
+from paddlenlp.trainer.utils.helper import distributed_isfile
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _add_variant,
@@ -67,7 +66,7 @@
     SAFE_WEIGHTS_NAME,
 )
 from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import flatten_list, nested_copy
+from paddlenlp.utils.nested import nested_copy
 
 if is_safetensors_available():
     from safetensors.numpy import save_file as safe_save_file
@@ -77,6 +76,7 @@
 else:
     from paddlenlp.utils.safetensors import fast_load_file as load_file
 
+from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer
 from .shared_memory_utils import (
     _read_state_dict_from_shm,
     _traverse_copy_to_shm,
@@ -108,13 +108,13 @@
     get_sharded_file_name,
     get_sharded_index,
     is_need_master_weight,
+    is_sharding_split_param_mode,
     mapping_optimizer_tp_actions,
     merge_tensor_parallel_for_optimizer,
     merge_tensor_parallel_with_shard,
     reduce_master_weights_status,
     rename_shard_file,
-    save_config,
-    save_prefix_past_key_value,
+    save_model_config,
     select_model_weight_index,
     update_master_weight_status,
 )
@@ -361,25 +361,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
             json.dump(sharded_index, f, indent=4)
 
         if self.args.should_save:
-            # Save prefix model past_key_values
-            if isinstance(model_to_save, PrefixModelForCausalLM):
-                save_prefix_past_key_value(model_to_save, save_directory)
-                model_to_save.prefix_config.save_pretrained(save_directory)
-            if isinstance(model_to_save, LoRAModel):
-                model_to_save.lora_config.save_pretrained(save_directory)
-
-            # save the config
-            config_to_save = save_config(model_to_save)
-            # Attach architecture to the config
-            if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
-                config_to_save.architectures = [model_to_save.model.__class__.__name__]
-            else:
-                config_to_save.architectures = [model_to_save.__class__.__name__]
-            if self.args.should_save:
-                config_to_save.save_pretrained(save_directory)
-            # save generation config
-            if model_to_save.can_generate():
-                model_to_save.generation_config.save_pretrained(save_directory)
+            save_model_config(model_to_save, save_directory)
+
         paddle.device.cuda.empty_cache()
 
         if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
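Note: the deleted inline logic above (prefix past_key_values, LoRA/prefix configs, model config with attached architectures, generation config) is consolidated into the new save_model_config helper that this commit adds to the import block earlier in the file. A minimal sketch of that helper, reconstructed from the deleted lines purely for illustration (whether it still routes through save_config, and its exact module, are assumptions; the nested should_save guard is dropped because the call site is already guarded):

# Hypothetical reconstruction of save_model_config, assembled from the removed
# inline logic; the actual helper lives in the new unified_checkpoint utils module.
# PrefixModelForCausalLM, LoRAModel, save_prefix_past_key_value and save_config
# are the names already imported in this file.
def save_model_config(model_to_save, save_directory):
    # Save prefix past_key_values and the PEFT-specific configs first.
    if isinstance(model_to_save, PrefixModelForCausalLM):
        save_prefix_past_key_value(model_to_save, save_directory)
        model_to_save.prefix_config.save_pretrained(save_directory)
    if isinstance(model_to_save, LoRAModel):
        model_to_save.lora_config.save_pretrained(save_directory)

    # Save the model config, attaching the wrapped architecture for PEFT models.
    config_to_save = save_config(model_to_save)
    if isinstance(model_to_save, (LoRAModel, PrefixModelForCausalLM)):
        config_to_save.architectures = [model_to_save.model.__class__.__name__]
    else:
        config_to_save.architectures = [model_to_save.__class__.__name__]
    config_to_save.save_pretrained(save_directory)

    # Save the generation config when the model supports generation.
    if model_to_save.can_generate():
        model_to_save.generation_config.save_pretrained(save_directory)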
@@ -391,7 +374,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
             }
             paddle.save(save_info, os.path.join(save_directory, ".saving_info"))
 
-    def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str):
+    def load_unified_checkpoint(self, model, resume_from_checkpoint: str):
         """Load potential model checkpoint
 
         Args:
539522
save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal
540523
return
541524

542-
if (
543-
self.args.sharding_parallel_degree > 1
544-
and ShardingOption.SHARD_OP in self.args.sharding
545-
and "split_param" in self.args.sharding_parallel_config
546-
):
525+
if is_sharding_split_param_mode(self.args):
547526
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
548527
else:
549528
optim_state_dict = nested_copy(optimizer.state_dict())
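The three-clause sharding check deleted above now lives behind is_sharding_split_param_mode, added to the utils import block near the top of this file. A plausible sketch of that predicate, inferred directly from the condition it replaces (the module it is actually defined in is an assumption):

# Hypothetical reconstruction of is_sharding_split_param_mode, mirroring the
# removed inline check; ShardingOption comes from paddlenlp.trainer.trainer_utils.
from paddlenlp.trainer.trainer_utils import ShardingOption

def is_sharding_split_param_mode(args):
    # True when optimizer-state sharding (SHARD_OP) is combined with the
    # split_param option, which requires the gather/merge path for optimizer state.
    return (
        args.sharding_parallel_degree > 1
        and ShardingOption.SHARD_OP in args.sharding
        and "split_param" in args.sharding_parallel_config
    )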
@@ -867,11 +846,7 @@ def unified_checkpoint_into_shards(
 
 def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
     # Special process with split param.
-    if (
-        args.sharding_parallel_degree > 1
-        and ShardingOption.SHARD_OP in args.sharding
-        and "split_param" in args.sharding_parallel_config
-    ):
+    if is_sharding_split_param_mode(args):
         returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
         return returned_optim_state_dict
 
@@ -1118,211 +1093,3 @@ def unified_optimizer_into_shards(
         (optim_state_dict, shard_optimizer_file, sharded_optim_index),
         (master_weights, shard_master_weight_file, sharded_master_weight_index),
     ]
-
-
-def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
-    index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)
-    index_filename = os.path.join(resume_from_checkpoint, index_filename)
-    # Find index json file and distribute this file in global group.
-    if distributed_isfile(index_filename):
-        distributed_file(index_filename)
-    else:
-        raise Exception(
-            f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine."
-        )
-
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
-    all_weight_filenames = sorted(set(index["weight_map"].values()))
-
-    # Get existed weight file list on current machine.
-    existed_filelist = []
-    existed_files = []
-    for filename in os.listdir(resume_from_checkpoint):
-        if filename in all_weight_filenames:
-            existed_files.append(filename)
-
-    # Gather all the existed files in global group.
-    dist.all_gather_object(existed_filelist, existed_files)
-    flatten_existed_filelist = flatten_list(existed_filelist)
-    diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist)))
-    if len(diff_filelist) != 0:
-        raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}")
-
-    # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines.
-    local_resume = True
-    if args.dataset_rank == 0 or args.use_expert_parallel:
-        hcg = fleet.get_hybrid_communicate_group()
-        tp_group = hcg.get_model_parallel_group()
-        pp_group = hcg.get_pipe_parallel_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-
-        need_files = set()
-        state_dict = get_expected_state_dict(model)
-        for key in state_dict.keys():
-            filename = index["weight_map"][key]
-            # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
-            if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                continue
-            need_files.add(filename)
-        diff_filelist = list(need_files.difference(set(existed_files)))
-        num_diff = paddle.to_tensor([len(diff_filelist)])
-        if tp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group)
-        if pp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
-        if args.use_expert_parallel and dp_group.nranks > 1:
-            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)
-        if num_diff.item() == 0:
-            local_resume = True
-        else:
-            local_resume = False
-    local_resume = paddle.to_tensor([local_resume])
-    dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
-    local_resume = local_resume.item()
-    return local_resume
-
-
-def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
-    if not safe_serialization:
-        index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME
-    else:
-        index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
-    index_filename = os.path.join(resume_from_checkpoint, index_filename)
-    index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights)
-
-    # Find index json file and distribute the file in global group.
-    if distributed_isfile(index_filename):
-        distributed_file(index_filename)
-    else:
-        raise Exception(
-            f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine."
-        )
-
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
-    all_optimizer_filenames = sorted(set(index["weight_map"].values()))
-
-    has_master_weights = index["master_weights"]
-    # update has_master_weights and index_filename_master_weights
-    # 1. if the master weight exists, only has_master_weights is set True and loaded when needed
-    # 2. if master weight does not exist, convert model weight to master weight when needed
-    has_master_weights, index_filename_master_weights = update_master_weight_status(
-        args, optimizer, has_master_weights, safe_serialization
-    )
-    if has_master_weights:
-        index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights)
-        if distributed_isfile(index_filename_master_weights):
-            distributed_file(index_filename_master_weights)
-        else:
-            raise Exception(
-                f"Sorry, we can not find {index_filename_master_weights}. This file should be appear at least on one machine."
-            )
-        with open(index_filename_master_weights, "r") as f:
-            index_mw = json.loads(f.read())
-        all_mw_filenames = sorted(set(index_mw["weight_map"].values()))
-
-    hcg = fleet.get_hybrid_communicate_group()
-    tp_group = hcg.get_model_parallel_group()
-    pp_group = hcg.get_pipe_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
-    sharding_group = hcg.get_sharding_parallel_group()
-    sharding_rank = sharding_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-    struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
-
-    if (
-        args.sharding_parallel_degree > 1
-        and ShardingOption.SHARD_OP in args.sharding
-        and "split_param" in args.sharding_parallel_config
-    ):
-        # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume.
-        logger.warning("We only support local resume for split_param mode, do not support dynamically loading.")
-        return True
-
-    if sharding_group.nranks > 1:
-        param2rank = optimizer._param2rank
-
-    def check_complete(all_filenames):
-        # Check whether the checkpoint files on machines are complete. If not complete, raise Exception.
-        existed_filelist = []
-        existed_files = []
-        for filename in os.listdir(resume_from_checkpoint):
-            if filename in all_filenames:
-                existed_files.append(filename)
-
-        dist.all_gather_object(existed_filelist, existed_files)
-        flatten_existed_filelist = flatten_list(existed_filelist)
-        diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist)))
-        if len(diff_filelist) != 0:
-            raise Exception(
-                f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}"
-            )
-        return existed_files
-
-    def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None):
-        # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint.
-        local_resume = True
-        if args.data_parallel_rank == 0 or args.use_expert_parallel:
-            need_files = set()
-            state_dict = get_expected_state_dict(model)
-
-            for key in state_dict.keys():
-                if sharding_group.nranks > 1:
-                    static_name = struct2static_name_mappings.get(key, None)
-                    param_rank = param2rank.get(static_name, None)
-                    if param_rank != sharding_rank:
-                        continue
-
-                # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
-                if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                    continue
-
-                if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
-                    continue
-
-                if not is_master_weights:
-                    for type_name in typename_set:
-                        type_key = key + "/" + type_name
-                        filename = weight_map[type_key]
-                        need_files.add(filename)
-                else:
-                    filename = weight_map[key]
-                    need_files.add(filename)
-
-            diff_filelist = list(need_files.difference(set(existed_files)))
-            num_diff = paddle.to_tensor([len(diff_filelist)])
-            if tp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group)
-            if pp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
-            if sharding_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group)
-            if args.use_expert_parallel and dp_group.nranks > 1:
-                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)
-
-            if num_diff.item() == 0:
-                local_resume = True
-            else:
-                local_resume = False
-        local_resume = paddle.to_tensor([local_resume])
-        dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
-        return local_resume.item()
-
-    # check whether the optimizer checkpoint files are complete.
-    existed_files = check_complete(all_optimizer_filenames)
-    if has_master_weights:
-        existed_files_mw = check_complete(all_mw_filenames)
-    # get optimizer's param type name, like moment1_0.
-    typename_set = set()
-    for key in index["weight_map"].keys():
-        _, typename = key.split("/")
-        typename_set.add(typename)
-    local_resume = check_dynamic_load(
-        args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set
-    )
-    local_resume_rw = True
-    if has_master_weights:
-        local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True)
-    return local_resume & local_resume_rw
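The two check functions removed here are relocated rather than deleted: the import added near the top of this file (from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer) points to a new check_unified_checkpoint module inside the unified_checkpoint package. A small illustrative sketch of how calling code can keep using them after the move; the wrapper function and the absolute import path are assumptions based on the renamed file locations, not code from this commit:

# Illustrative only: the check functions keep the signatures shown in the removed
# code, they just resolve through the new module under paddlenlp/trainer/unified_checkpoint/.
from paddlenlp.trainer.unified_checkpoint.check_unified_checkpoint import (
    check_unified_checkpoint,
    check_unified_optimizer,
)

def can_resume_locally(args, model, optimizer, ckpt_dir, safe_serialization=True):
    # True only when every shard this rank needs is already present locally,
    # for both the model weights and the optimizer state.
    model_ok = check_unified_checkpoint(args, model, ckpt_dir, safe_serialization=safe_serialization)
    optim_ok = check_unified_optimizer(args, model, optimizer, ckpt_dir, safe_serialization=safe_serialization)
    return bool(model_ok) and bool(optim_ok)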

paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py renamed to paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py

Lines changed: 4 additions & 7 deletions
@@ -38,9 +38,10 @@
 )
 
 
-def distributed_send_recv_splited_param(
+def merge_splited_param(
     state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
 ):
+    """Merge the splited param in sharding group."""
     global_rank = dist.get_rank()
     for key in list(state_dict.keys()):
         if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
@@ -144,13 +145,9 @@ def gather_splited_param_for_optimizer(optimizer):
         recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
         send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]
 
-    distributed_send_recv_splited_param(
-        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False
-    )
+    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
     if master_weights is not None:
-        distributed_send_recv_splited_param(
-            master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True
-        )
+        merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
     return optim_state_dict, master_weights