
Commit 37f3be1

fix expert parallel
1 parent 3967f76 commit 37f3be1

2 files changed: +32, -20 lines

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 27 additions & 3 deletions
@@ -516,6 +516,15 @@ def unified_checkpoint_into_shards(
 
     config_to_save = copy.deepcopy(model_to_save.config)
 
+    if args.use_expert_parallel:
+        # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+        for key in list(state_dict.keys()):
+            if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
+                state_dict.pop(key)
+
     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
             tp_actions = model_to_save._get_tensor_parallel_convert_actions(
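The block added above changes only what each data-parallel rank writes when `args.use_expert_parallel` is enabled: `dp_rank == 0` still saves the full model state, while ranks with `dp_rank > 0` keep just the tensors carrying a truthy `no_sync` attribute, which the in-code comment implies marks expert weights that are not replicated across data-parallel ranks. Below is a minimal sketch of that rule under that assumption; `FakeTensor` and `filter_model_state_for_dp_rank` are hypothetical stand-ins, not PaddleNLP code.

```python
class FakeTensor:
    """Stand-in for a paddle.Tensor; only the optional `no_sync` flag matters here."""

    def __init__(self, no_sync=False):
        self.no_sync = no_sync


def filter_model_state_for_dp_rank(state_dict, dp_rank, use_expert_parallel=True):
    """Hypothetical helper: keep everything on dp_rank 0, only no_sync tensors elsewhere."""
    if not use_expert_parallel or dp_rank == 0:
        return state_dict
    return {k: v for k, v in state_dict.items() if getattr(v, "no_sync", False)}


state = {
    "embed.weight": FakeTensor(),                 # replicated weight, saved by dp_rank 0 only
    "moe.expert_3.w1": FakeTensor(no_sync=True),  # expert weight assumed local to this dp rank
}
print(sorted(filter_model_state_for_dp_rank(state, dp_rank=0)))  # ['embed.weight', 'moe.expert_3.w1']
print(sorted(filter_model_state_for_dp_rank(state, dp_rank=1)))  # ['moe.expert_3.w1']
```

Doing the filtering here, before the shards are built, is what allows the dp_rank/no_sync checks to be deleted from the merge helpers in utils.py further down.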
@@ -622,8 +631,25 @@ def unified_optimizer_into_shards(
     filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True)
     filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True)
 
-    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+    hcg = fleet.get_hybrid_communicate_group()
+    tp_group = hcg.get_model_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
     tp_size = tp_group.nranks
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+
+    no_sync_kname = []
+    if args.use_expert_parallel:
+        for k, v in state_dict.items():
+            if getattr(state_dict[k], "no_sync", False):
+                no_sync_kname.append(k)
+        for key in list(optim_state_dict.keys()):
+            model_key = key.split("/")[0]
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                optim_state_dict.pop(key)
+        if master_weights is not None:
+            for key in list(master_weights.keys()):
+                if dp_rank > 0 and key not in no_sync_kname:
+                    master_weights.pop(key)
 
     if tp_size > 1:
         # get tp_actions
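The optimizer path applies the same rule one level up from the merge helper it previously lived in. Judging by the `key.split("/")[0]` above, each optimizer-state key is assumed to look like `"<model_param_name>/<optimizer_state_name>"`, while master weights are keyed by the model parameter name directly; the `/moment1` suffix below is invented for illustration. A hedged sketch with `SimpleNamespace` objects in place of paddle Tensors:

```python
from types import SimpleNamespace


def filter_optimizer_state_for_dp_rank(model_state, optim_state, master_weights, dp_rank):
    """Hypothetical helper mirroring the filtering above, on plain dicts."""
    # Parameters marked no_sync=True are assumed expert-local, so each dp rank keeps its own
    # optimizer state and master weight for them; everything else is left to dp_rank 0.
    no_sync_kname = [k for k, v in model_state.items() if getattr(v, "no_sync", False)]
    if dp_rank > 0:
        for key in list(optim_state.keys()):
            if key.split("/")[0] not in no_sync_kname:
                optim_state.pop(key)
        if master_weights is not None:
            for key in list(master_weights.keys()):
                if key not in no_sync_kname:
                    master_weights.pop(key)
    return optim_state, master_weights


model_state = {
    "embed.weight": SimpleNamespace(no_sync=False),
    "moe.expert_3.w1": SimpleNamespace(no_sync=True),
}
optim_state = {"embed.weight/moment1": 1.0, "moe.expert_3.w1/moment1": 2.0}
master_weights = {"embed.weight": 1.0, "moe.expert_3.w1": 2.0}
optim_state, master_weights = filter_optimizer_state_for_dp_rank(
    model_state, optim_state, master_weights, dp_rank=1
)
print(sorted(optim_state))     # ['moe.expert_3.w1/moment1']
print(sorted(master_weights))  # ['moe.expert_3.w1']
```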
@@ -643,7 +669,6 @@ def unified_optimizer_into_shards(
             optim_state_dict,
             tp_actions,
             filter_optim_keys,
-            state_dict if args.use_expert_parallel else None,
         )
         paddle.device.cuda.empty_cache()
 
@@ -653,7 +678,6 @@ def unified_optimizer_into_shards(
             master_weights,
             tp_actions,
             filter_master_keys,
-            state_dict if args.use_expert_parallel else None,
         )
         paddle.device.cuda.empty_cache()
 

paddlenlp/trainer/unified_checkpoint/utils.py

Lines changed: 5 additions & 17 deletions
@@ -354,9 +354,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
 
     # filter actions for pipeline mode
     if hcg.get_pipe_parallel_group().nranks > 1:
@@ -373,10 +371,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             if i > len(filter_keys) - 1:
                 continue
             key = filter_keys[i]
-            tensor = state_dict[key]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and not getattr(tensor, "no_sync", False):
+            if key not in state_dict:
                 continue
+            tensor = state_dict[key]
             if key in tp_actions:
                 # Get tensor size
                 tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
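With the dp_rank/no_sync filtering moved into unified_checkpoint.py, the merge helpers here no longer need the model state dict at all; they simply skip any key that was popped upstream. A toy sketch of that "filter early, skip missing keys later" pattern (the helper name and loop shape are simplified from the real round-robin merge):

```python
def collect_mergeable_tensors(state_dict, filter_keys):
    """Hypothetical helper: gather only the assigned keys that still exist on this rank."""
    collected = {}
    for key in filter_keys:
        if key not in state_dict:  # popped upstream, e.g. on dp_rank > 0 under expert parallel
            continue
        collected[key] = state_dict[key]
    return collected


local_state = {"moe.expert_3.w1": "tensor"}  # "embed.weight" was filtered out on this rank
print(collect_mergeable_tensors(local_state, ["embed.weight", "moe.expert_3.w1"]))
# {'moe.expert_3.w1': 'tensor'}
```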
@@ -405,21 +402,13 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     return state_dict_to_save
 
 
-def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
+def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
     """
     Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight.
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-
-    no_sync_kname = []
-    if model_state_dict is not None:
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
 
     state_dict_to_save = {}
     max_key_len = max([len(_) for _ in all_filter_keys])
@@ -430,10 +419,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
                 continue
             # get base model key
             model_key = filter_keys[i].split("/")[0]
-            tensor = state_dict[filter_keys[i]]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and model_key not in no_sync_kname:
+            if filter_keys[i] not in state_dict:
                 continue
+            tensor = state_dict[filter_keys[i]]
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
