Skip to content

Commit 7f806e8

Browse files
committed
fix split_param
1 parent 2f85a64 commit 7f806e8

File tree

3 files changed: 70 additions (+70) and 29 deletions (-29).

paddlenlp/trainer/unified_checkpoint/load_local.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -149,14 +149,6 @@ def _remove_unused_keys(
149149

150150

151151
def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
152-
# Special process with split param.
153-
if is_sharding_split_param_mode(args):
154-
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
155-
return returned_optim_state_dict
156-
157-
# init and get optimizer LR_Scheduler
158-
returned_optim_state_dict = nested_copy(optimizer.state_dict())
159-
160152
if not safe_serialization:
161153
index_filename, index_filename_master_weights = (
162154
PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
165157
else:
166158
index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
167159

160+
with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
161+
index = json.loads(f.read())
162+
163+
ckpt_quant_stage = "O0"
164+
if "ckpt_quant_stage" in index:
165+
ckpt_quant_stage = index["ckpt_quant_stage"]
166+
167+
# Special process with split param.
168+
if is_sharding_split_param_mode(args):
169+
returned_optim_state_dict = load_unified_optimizer_split_param(
170+
args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
171+
)
172+
return returned_optim_state_dict
173+
174+
# init and get optimizer LR_Scheduler
175+
returned_optim_state_dict = nested_copy(optimizer.state_dict())
176+
168177
resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
169178
optimizer_path=resume_from_checkpoint,
170179
index_filename=os.path.join(resume_from_checkpoint, index_filename),
@@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
184193
if len(resolved_archive_file) > 1:
185194
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
186195

187-
with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
188-
index = json.loads(f.read())
189-
190-
ckpt_quant_stage = "O0"
191-
if "ckpt_quant_stage" in index:
192-
ckpt_quant_stage = index["ckpt_quant_stage"]
193-
194196
# update has_master_weights and index_filename_master_weights
195197
# 1. if the master weight exists, only has_master_weights is set True and loaded when needed
196198
# 2. if master weight does not exist, convert model weight to master weight when needed

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 50 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,18 @@
4242

4343

4444
def merge_splited_param(
45-
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
45+
state_dict,
46+
partial_tensor_list,
47+
param_shape_info,
48+
send_table,
49+
recv_table,
50+
is_master_weights=False,
51+
ckpt_quant_stage="O0",
4652
):
4753
"""Merge the splited param in sharding group."""
4854
global_rank = dist.get_rank()
4955
for key in list(state_dict.keys()):
50-
if state_dict[key].numel().item() == 1: # for example: beta1, beta2
56+
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
5157
continue
5258

5359
static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +95,21 @@ def merge_splited_param(
8995
)
9096
dist.stream.send(tensor, dst=recv_rank)
9197
state_dict.pop(key)
98+
99+
if ckpt_quant_stage != "O0":
100+
for key in list(state_dict.keys()):
101+
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
102+
static_name = key if is_master_weights else generate_base_static_name(key)[0]
103+
if static_name in partial_tensor_list:
104+
recv_rank = recv_table[static_name]
105+
send_info = send_table[static_name]
106+
if global_rank != recv_rank:
107+
state_dict.pop(key)
108+
92109
return state_dict
93110

94111

95-
def gather_splited_param_for_optimizer(optimizer):
112+
def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
96113
hcg = fleet.get_hybrid_communicate_group()
97114
sharding_group = hcg.get_sharding_parallel_group()
98115
global_rank = dist.get_rank()
@@ -127,7 +144,7 @@ def gather_splited_param_for_optimizer(optimizer):
127144
for key in list(optim_state_dict.keys()):
128145
static_name, _ = generate_base_static_name(key)
129146
if static_name in param_slice_info.keys():
130-
if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2
147+
if int(optim_state_dict[key].numel()) == 1: # for example: beta1, beta2
131148
continue
132149
begin, end = param_slice_info[static_name]
133150
shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +166,17 @@ def gather_splited_param_for_optimizer(optimizer):
149166
recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor
150167
send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]
151168

152-
merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
169+
merge_splited_param(
170+
optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
171+
)
153172
if master_weights is not None:
154-
merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
173+
merge_splited_param(
174+
master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True, ckpt_quant_stage
175+
)
155176
return optim_state_dict, master_weights
156177

157178

158-
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
179+
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
159180
returned_optim_state_dict = nested_copy(optimizer.state_dict())
160181

161182
index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -217,7 +238,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
217238
if len(resolved_archive_file_mw) > 1:
218239
resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")
219240

220-
def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
241+
def load_resolved_archive_file(
242+
resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
243+
):
221244
returned_state_dict = {}
222245

223246
if model.config.tensor_parallel_degree > 1:
@@ -232,24 +255,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
232255
if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
233256
continue
234257
if model.config.tensor_parallel_degree > 1:
235-
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
258+
state_dict = load_state_dict(
259+
shard_file,
260+
tp_actions,
261+
expected_keys,
262+
device="cpu",
263+
ckpt_quant_stage=ckpt_quant_stage,
264+
)
236265
else:
237-
state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
266+
state_dict = load_state_dict(
267+
shard_file,
268+
None,
269+
expected_keys,
270+
device="cpu",
271+
ckpt_quant_stage=ckpt_quant_stage,
272+
)
238273
returned_state_dict.update(state_dict)
239274
del state_dict
240275
gc.collect()
241276

242277
return returned_state_dict
243278

244279
# get tp params
245-
state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
280+
state_dict_optim = load_resolved_archive_file(
281+
resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
282+
)
246283

247284
# need to split param for different sharding rank, maybe need to deal with oom issue.
248285
for key in list(state_dict_optim.keys()):
249286
key_name = key.split("/")
250287
static_name = struct2static_name_mappings.get(key_name[0], None)
251288

252-
if state_dict_optim[key].numel().item() > 1:
289+
if int(state_dict_optim[key].numel()) > 1:
253290
begin, end = param_slice_info[static_name]
254291
shape, numel, index, padded_size = param_shape_info[static_name]
255292
state_dict_optim[key] = state_dict_optim[key].reshape([-1])
@@ -284,7 +321,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
284321

285322
for key in list(state_dict_master_weight.keys()):
286323
static_name = struct2static_name_mappings.get(key, None)
287-
if state_dict_master_weight[key].numel().item() > 1:
324+
if int(state_dict_master_weight[key].numel()) > 1:
288325
begin, end = param_slice_info[static_name]
289326
shape, numel, index, padded_size = param_shape_info[static_name]
290327
state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
344344
return
345345

346346
if is_sharding_split_param_mode(self.args):
347-
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
347+
optim_state_dict, master_weights = gather_splited_param_for_optimizer(
348+
optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
349+
)
348350
else:
349351
optim_state_dict = nested_copy(optimizer.state_dict())
350352
master_weights = None

0 commit comments

Comments
 (0)