Commit 8bed006

fix split_param for expert parallel

1 parent 37f3be1 commit 8bed006

File tree

2 files changed: +6 −2 lines

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 5 additions & 1 deletion

@@ -305,7 +305,11 @@ def load_resolved_archive_file(
                 )
             )
             if has_master_weights:
-                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                if model_state_dict[key_name[0]].dtype != paddle.float32:
+                    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                else:
+                    # for moe gate with float32 dtype.
+                    key_name = "_".join([static_name, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
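The hunk above keys optimizer state on the FP32_MASTER-tagged name only when the parameter actually has a separate fp32 master copy; a parameter already stored in float32 (such as an MoE gate kept in float32 under expert parallel) is keyed on the plain static name. A minimal sketch of that key resolution, assuming key_name arrives as a (param_name, suffix) pair and model_state_dict maps param_name to a paddle tensor; the helper name and the FP32_MASTER value shown are illustrative, not the repository's API:

import paddle

FP32_MASTER = "fp32_master_0"  # master-weight tag; value shown here for illustration

def resolve_optimizer_key(param_name, static_name, suffix, model_state_dict, has_master_weights):
    # fp16/bf16 params keep a separate fp32 master copy, so their optimizer
    # state lives under the master-weight key.
    if has_master_weights and model_state_dict[param_name].dtype != paddle.float32:
        return "_".join([static_name, FP32_MASTER, suffix])
    # Params already stored in float32 (e.g. an MoE gate) have no master copy,
    # so their optimizer state is keyed on the plain static name.
    return "_".join([static_name, suffix])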

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -637,8 +637,8 @@ def unified_optimizer_into_shards(
     tp_size = tp_group.nranks
     dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

-    no_sync_kname = []
     if args.use_expert_parallel:
+        no_sync_kname = []
         for k, v in state_dict.items():
             if getattr(state_dict[k], "no_sync", False):
                 no_sync_kname.append(k)
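This hunk moves the no_sync_kname initialization inside the use_expert_parallel branch, so the list of no-sync optimizer keys is only built when expert parallel is enabled. A small sketch of that collection step, assuming tensors that expert parallel keeps local to a rank carry a truthy no_sync attribute; the standalone helper is illustrative and not the function in the file:

def collect_no_sync_keys(state_dict, use_expert_parallel):
    # Keys of optimizer entries whose tensors are marked no_sync,
    # i.e. kept local to an expert-parallel rank.
    no_sync_kname = []
    if use_expert_parallel:
        for k, v in state_dict.items():
            if getattr(v, "no_sync", False):
                no_sync_kname.append(k)
    return no_sync_kname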
