Commit c0969e2

Fix split_param (#9817) (#9818)
1 parent 465ce1d commit c0969e2

File tree: 3 files changed (+79, -28 lines changed)

paddlenlp/trainer/unified_checkpoint/load_local.py

Lines changed: 17 additions & 15 deletions
@@ -149,14 +149,6 @@ def _remove_unused_keys(


 def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
-    # Special process with split param.
-    if is_sharding_split_param_mode(args):
-        returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
-        return returned_optim_state_dict
-
-    # init and get optimizer LR_Scheduler
-    returned_optim_state_dict = nested_copy(optimizer.state_dict())
-
     if not safe_serialization:
         index_filename, index_filename_master_weights = (
             PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
     else:
         index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME

+    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
+        index = json.loads(f.read())
+
+    ckpt_quant_stage = "O0"
+    if "ckpt_quant_stage" in index:
+        ckpt_quant_stage = index["ckpt_quant_stage"]
+
+    # Special process with split param.
+    if is_sharding_split_param_mode(args):
+        returned_optim_state_dict = load_unified_optimizer_split_param(
+            args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
+        )
+        return returned_optim_state_dict
+
+    # init and get optimizer LR_Scheduler
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
     resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
         optimizer_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
@@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

-    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
-        index = json.loads(f.read())
-
-    ckpt_quant_stage = "O0"
-    if "ckpt_quant_stage" in index:
-        ckpt_quant_stage = index["ckpt_quant_stage"]
-
     # update has_master_weights and index_filename_master_weights
     # 1. if the master weight exists, only has_master_weights is set True and loaded when needed
     # 2. if master weight does not exist, convert model weight to master weight when needed
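
For reference, a minimal sketch of the index lookup this change moves to the top of `load_unified_optimizer_locally`: the optimizer index JSON is read first and `ckpt_quant_stage` resolved (defaulting to `"O0"`) so it can be forwarded to `load_unified_optimizer_split_param`. The helper name below is illustrative, not part of PaddleNLP.

```python
import json
import os


def read_ckpt_quant_stage(resume_from_checkpoint: str, index_filename: str) -> str:
    """Return the quantization stage recorded in the optimizer index, defaulting to "O0"."""
    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
        index = json.loads(f.read())
    return index["ckpt_quant_stage"] if "ckpt_quant_stage" in index else "O0"
```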

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 59 additions & 12 deletions
@@ -36,18 +36,25 @@
     get_expected_state_dict,
     get_optimizer_shard_files,
     mapping_optimizer_tp_actions,
+    update_master_weight_status,
 )

 __all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]


 def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
 ):
     """Merge the splited param in sharding group."""
     global_rank = dist.get_rank()
     for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
             continue

         static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
                         )
                         dist.stream.send(tensor, dst=recv_rank)
                 state_dict.pop(key)
+
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
     return state_dict


-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     global_rank = dist.get_rank()
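
To illustrate the cleanup pass added to `merge_splited_param` for quantized checkpoints (`ckpt_quant_stage != "O0"`): scalar entries such as the beta1/beta2 accumulators that belong to a split ("partial") parameter are kept only on the rank that receives the merged tensor. A toy, framework-free sketch of that filtering rule, with plain Python values standing in for Paddle tensors and a simplified stand-in for `generate_base_static_name`:

```python
def drop_scalar_entries_for_quant(state_dict, partial_tensor_list, recv_table, global_rank, ckpt_quant_stage="O0"):
    """Toy version of the extra pass: for quantized checkpoints, scalar entries
    tied to a split parameter survive only on the receiving rank."""
    if ckpt_quant_stage == "O0":
        return state_dict
    for key in list(state_dict.keys()):
        if isinstance(state_dict[key], (int, float)):  # stands in for numel() == 1 tensors
            static_name = key.split("/")[0]  # simplified generate_base_static_name
            if static_name in partial_tensor_list and global_rank != recv_table[static_name]:
                state_dict.pop(key)
    return state_dict


# Rank 1 is not the receiver for "linear_0.w_0", so its beta1 accumulator is dropped.
state = {"linear_0.w_0/moment1_0": [0.1, 0.2], "linear_0.w_0/beta1_pow_acc_0": 0.9}
drop_scalar_entries_for_quant(state, ["linear_0.w_0"], {"linear_0.w_0": 0}, global_rank=1, ckpt_quant_stage="O1")
assert "linear_0.w_0/beta1_pow_acc_0" not in state
```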
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
     for key in list(optim_state_dict.keys()):
         static_name, _ = generate_base_static_name(key)
         if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                 continue
             begin, end = param_slice_info[static_name]
             shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer):
         recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
         send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
     if master_weights is not None:
         merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
     return optim_state_dict, master_weights


-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())

     index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

+    has_master_weights, index_filename_master_weights = update_master_weight_status(
+        args, optimizer, has_master_weights, safe_serialization=True
+    )
+
     if has_master_weights:
         returned_optim_state_dict["master_weights"] = {}
         resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
         if len(resolved_archive_file_mw) > 1:
             resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
         returned_state_dict = {}

         if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()

         return returned_state_dict

     # get tp params
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )

     # need to split param for different sharding rank, maybe need to deal with oom issue.
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)

-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
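
The loop that begins here re-splits each fully loaded optimizer tensor so every sharding rank keeps only its own slice. A rough, framework-free illustration of that idea, assuming a flattened buffer padded to `padded_size` and a per-rank `[begin, end)` range (the exact padding and slicing performed by PaddleNLP is not shown in this hunk, so the details below are illustrative):

```python
def take_local_slice(flat_param, padded_size, begin, end):
    """Pad the flattened tensor to padded_size, then keep this rank's [begin, end) slice."""
    padded = list(flat_param) + [0.0] * (padded_size - len(flat_param))
    return padded[begin:end]


# A 10-element tensor padded to 12 and split across two sharding ranks of 6 elements each.
full = [float(i) for i in range(10)]
assert take_local_slice(full, 12, 0, 6) + take_local_slice(full, 12, 6, 12) == full + [0.0, 0.0]
```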
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):

         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings.get(key, None)
-            if state_dict_master_weight[key].numel().item() > 1:
+            if int(state_dict_master_weight[key].numel()) > 1:
                 begin, end = param_slice_info[static_name]
                 shape, numel, index, padded_size = param_shape_info[static_name]
                 state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
                 paddle.framework._current_expected_place(), False
             )
             returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
+
+            # master weight cast (only in remove_master_weight)
+            if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
+                returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
+                    returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
+                )
+
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     return returned_optim_state_dict

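The dtype guard added at the end of the master-weights loop casts any master weight stored in reduced precision back to float32 before it is handed to the optimizer. A minimal sketch of that guard in isolation (the helper name is illustrative; `paddle.cast` and `paddle.float32` are the same APIs used in the diff):

```python
import paddle


def ensure_fp32_master_weight(tensor: paddle.Tensor) -> paddle.Tensor:
    """Cast a loaded master weight back to float32 if the checkpoint stored it in lower precision."""
    if tensor.dtype != paddle.float32:
        tensor = paddle.cast(tensor, dtype=paddle.float32)
    return tensor


# Example: a bfloat16 master weight read from a compressed checkpoint becomes float32 again.
w = paddle.ones([4], dtype="bfloat16")
assert ensure_fp32_master_weight(w).dtype == paddle.float32
```
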
paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
             return

         if is_sharding_split_param_mode(self.args):
-            optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
+            optim_state_dict, master_weights = gather_splited_param_for_optimizer(
+                optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
+            )
         else:
             optim_state_dict = nested_copy(optimizer.state_dict())
             master_weights = None
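
At the save-path call site, the stage forwarded to `gather_splited_param_for_optimizer` is the configured `args.ckpt_quant_stage` unless a `quant_reach_limit` marker has been recorded in `infohub`, in which case quantization is disabled for this save. A small illustrative helper capturing that selection (the function name and the plain dict standing in for `infohub` are assumptions):

```python
def select_ckpt_quant_stage(configured_stage: str, infohub: dict) -> str:
    """Mirror the call-site expression: keep the configured stage unless the quantization limit was reached."""
    return configured_stage if "quant_reach_limit" not in infohub else "O0"


assert select_ckpt_quant_stage("O1", {}) == "O1"
assert select_ckpt_quant_stage("O1", {"quant_reach_limit": True}) == "O0"
```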
