
Commit ff0ebc2

update split_param loading
1 parent 4ab0df1 commit ff0ebc2

File tree

1 file changed: +15 -12 lines changed

paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py

Lines changed: 15 additions & 12 deletions
@@ -224,12 +224,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         for shard_file in resolved_archive_file:
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
-
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
-
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()
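
This hunk switches shard loading from device="expected" to device="cpu": tensors are materialized in host memory and only moved to the accelerator after the per-rank split (see the hunks below). A minimal sketch of the pattern, with a stand-in tensor in place of a real checkpoint shard:

import paddle

# Stand-in for a parameter that load_state_dict(..., device="cpu") would
# return; a real tensor would come from a shard file.
full_param = paddle.rand([1024], dtype="float32").cpu()

# Split on CPU for this sharding rank; nothing touches the device yet.
rank, num_ranks = 1, 4
shard_size = full_param.numel().item() // num_ranks
shard = full_param[rank * shard_size : (rank + 1) * shard_size]

# Copy only this rank's slice to the expected place, so peak device
# memory is bounded by the shard size rather than the full parameter.
shard = shard._copy_to(paddle.framework._current_expected_place(), False)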
@@ -238,13 +236,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
 
     # get tp params
     state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
-    if has_master_weights:
-        state_dict_master_weight = load_resolved_archive_file(
-            resolved_archive_file_mw,
-            sharded_metadata_mw,
-            expected_keys,
-            is_master_weights=True,
-        )
 
     # need to split param for different sharding rank, maybe need to deal with oom issue.
     for key in list(state_dict_optim.keys()):
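
The eager master-weights load removed here reappears inside the if has_master_weights: branch further down (next hunk), so the optimizer shards are split and moved to the device before the master-weight shards are ever read. Condensed, with loop bodies elided and names as in the surrounding diff, the resulting flow is:

# Only one sharded state dict is resident at a time.
state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
for key in list(state_dict_optim.keys()):
    ...  # split for this sharding rank, pad, copy to device, pop

if has_master_weights:
    state_dict_master_weight = load_resolved_archive_file(
        resolved_archive_file_mw,
        sharded_metadata_mw,
        expected_keys,
        is_master_weights=True,
    )
    for key in list(state_dict_master_weight.keys()):
        ...  # same split/pad/copy, stored under "master_weights"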
@@ -266,15 +257,24 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
                 )
             )
-
         if has_master_weights:
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
+
+        state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False)
+
         returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name
 
     if has_master_weights:
+        state_dict_master_weight = load_resolved_archive_file(
+            resolved_archive_file_mw,
+            sharded_metadata_mw,
+            expected_keys,
+            is_master_weights=True,
+        )
+
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings.get(key, None)
             if state_dict_master_weight[key].numel().item() > 1:
@@ -292,6 +292,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                         paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
                     )
                 )
+            state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
+                paddle.framework._current_expected_place(), False
+            )
             returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
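Both loops now end with the same step: pad the per-rank slice on the host, then move the finished tensor to the expected place in a single copy. A standalone sketch of that step; padding_start and padding_end are treated here as precomputed bounds supplied by the surrounding (unchanged) code:

import paddle

def pad_and_place(shard, padding_start, padding_end):
    # Pad the slice to its full per-rank length, mirroring the
    # paddle.concat / paddle.zeros pattern in the diff above.
    if padding_end > padding_start:
        shard = paddle.concat(
            (
                shard,
                paddle.zeros([padding_end - padding_start], dtype=shard.dtype),
            )
        )
    # One copy to the device the program is configured to run on;
    # False requests a non-blocking copy, as in the commit.
    return shard._copy_to(paddle.framework._current_expected_place(), False)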