
Commit cd14081

[fea] moe support
1 parent 0cd8fe7

File tree: 4 files changed (+78, -15 lines)


paddlenlp/trainer/trainer.py

Lines changed: 34 additions & 10 deletions
@@ -945,7 +945,8 @@ def _inner_training_loop(
                     ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                ) or (args.recompute and availiable_no_sync
+                ) or (args.use_moe and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manually collect gradient on dp group

@@ -965,6 +966,14 @@ def _inner_training_loop(

                 tr_loss += tr_loss_step

+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no_sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps

@@ -983,12 +992,12 @@ def _inner_training_loop(

                     # Case 1: Use recompute and dp / sharding stage1,
                     # manually collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (

@@ -1007,8 +1016,7 @@ def _inner_training_loop(
                             self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                         if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                            fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                            fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()

@@ -1028,7 +1036,9 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline MoE does not work under fp16"
+                        scale_before = self.scaler._scale.numpy()
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                         scale_after = self.scaler._scale

@@ -2042,7 +2052,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

         model.train()
         inputs = self._prepare_inputs(inputs)
-
+        self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)

@@ -2053,7 +2063,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
+        self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()
         return loss.detach()

     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:

@@ -2142,6 +2152,18 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         if self.args.should_save_model_state and self.args.should_save:
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        # save moe optimizer and model state  # TODO: stored redundantly on every data-parallel rank by default
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")

     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

@@ -2245,6 +2267,8 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files may be on different nodes.
         need_to_rotate_checkpoints = False
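The helper introduced above drops any parameter tagged with a `no_sync` attribute from the data-parallel all-reduce, so MoE expert gradients stay local to their own data-parallel rank while dense gradients are still synchronized. Below is a minimal, framework-free sketch of that partitioning; the `Param` class and the way `no_sync` gets set are illustrative assumptions, not PaddleNLP APIs.

class Param:
    def __init__(self, name, no_sync=False):
        self.name = name
        self.no_sync = no_sync  # True -> keep this parameter's gradient rank-local

def split_for_allreduce(paramlist):
    """Split parameters into (dense, expert) groups by the `no_sync` flag, as the new helper does."""
    paramlist = list(paramlist)
    dense = [p for p in paramlist if not getattr(p, "no_sync", False)]
    expert = [p for p in paramlist if getattr(p, "no_sync", False)]
    return dense, expert

params = [Param("embedding"), Param("expert_0.w1", no_sync=True), Param("lm_head")]
dense, expert = split_for_allreduce(params)
print([p.name for p in dense])   # ['embedding', 'lm_head'] -> all-reduced across the dp group
print([p.name for p in expert])  # ['expert_0.w1']          -> left out of the all-reduce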

paddlenlp/trainer/training_args.py

Lines changed: 24 additions & 1 deletion
@@ -803,6 +803,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    use_moe: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to enable MoE (mixture-of-experts) training"},
+    )

     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))

@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
                     order = ["dp", "sharding", "pp", "sep", "mp"]
                 else:
                     order = ["dp", "sharding", "pp", "mp"]
+                if self.use_moe:
+                    order = order[1:-1] + ["dp", "mp"]

                 if is_segment_parallel_supported():
                     hybrid_configs = {

@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

     @property

@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
             name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)

         else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

-    def sharded_name_suffix(self, shard_id=None, pp_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:

@@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
                 name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                assert isinstance(moe_id, int)
+                name.append(self._format_name("moe", moe_id, self.data_parallel_degree))
             return "_".join(name)
         else:
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                return self._format_name("moe", moe_id, self.data_parallel_degree)
             return None

     @property
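With `use_moe` enabled, every data-parallel rank appends a `moe{rank}` segment to its optimizer and weight name suffixes, so rank-local expert state is written to distinct files instead of being overwritten by other ranks. A small sketch of the suffix composition follows; `format_name` mirrors what `_format_name` is assumed to do (prefix plus a two-digit zero-padded rank, matching the `moe{rank:0>2d}` pattern used in transformers/utils.py), and the sharding segment is omitted for brevity.

def format_name(prefix, rank, degree):
    # e.g. ("tp", 1, 4) -> "tp01"; two-digit padding is an assumption here
    return f"{prefix}{rank:0>2d}"

def optimizer_name_suffix(tp_rank, tp_degree, pp_rank, pp_degree, dp_rank, dp_degree, use_moe):
    name = []
    if tp_degree > 1:
        name.append(format_name("tp", tp_rank, tp_degree))
    if pp_degree > 1:
        name.append(format_name("pp", pp_rank, pp_degree))
    if use_moe:
        # every data-parallel rank now writes its own optimizer/weight files
        name.append(format_name("moe", dp_rank, dp_degree))
    return "_".join(name) if name else None

print(optimizer_name_suffix(1, 4, 0, 2, 3, 8, use_moe=True))  # tp01_pp00_moe03
print(optimizer_name_suffix(0, 1, 0, 1, 3, 8, use_moe=True))  # moe03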

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 14 additions & 4 deletions
@@ -444,12 +444,22 @@ def filter_func(name):

         master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group)
         model_state_dict = self.model.state_dict()
+        logger.info(f"state-dict keys: {state_dict.keys()}, count: {len(state_dict.keys())}")
         logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict)))
         for key, param in model_state_dict.items():
-            if param.name in master_weights:
-                assert param.shape == master_weights[param.name].shape
-                paddle.assign(master_weights[param.name].cuda(), model_state_dict[key])
-
+            if param.name in master_weights:
+                assert param.shape == master_weights[param.name].shape
+                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+            elif key in state_dict:
+                logger.info(f"key: {key} is in state_dict, but not in master_weights")
+                paddle.assign(state_dict[key], model_state_dict[key])
+            if param.name in sharding_group_param_names:
+                paddle.distributed.broadcast(
+                    model_state_dict[key],
+                    src=self.sharding_group.ranks[param2rank[param.name]],
+                    group=self.sharding_group,
+                    sync_op=True,
+                )
         logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict)))
         state_dict.update(model_state_dict)
         return state_dict
paddlenlp/transformers/utils.py

Lines changed: 6 additions & 0 deletions
@@ -818,8 +818,14 @@ def weight_name_suffix():
         name.append(f"tp{hcg.get_model_parallel_rank():0>2d}")
         if hcg.get_pipe_parallel_world_size() > 1:
             name.append(f"pp{hcg.get_stage_id():0>2d}")
+        if config and getattr(config, "moe_num_experts", 0):
+            dp_group = hcg.get_data_parallel_group()
+            name.append(f"moe{dp_group.rank:0>2d}")
         return "_".join(name)
     else:
+        if config and getattr(config, "moe_num_experts", 0):
+            rank = paddle.distributed.get_rank()
+            return f"moe{rank:0>2d}"
         return None

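When `moe_num_experts` is set on the config, each data-parallel rank gets its own weight-file suffix, so expert weights saved by different ranks do not collide. A short sketch of the resulting file names, assuming the usual `model_state.pdparams` base name and an `_add_variant`-style rule that inserts the suffix before the extension (both assumptions about existing PaddleNLP conventions, not part of this diff):

def add_variant(weights_name, variant):
    # insert the rank suffix before the file extension,
    # e.g. "model_state.pdparams" -> "model_state.moe01.pdparams"
    if not variant:
        return weights_name
    stem, ext = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{ext}"

for dp_rank in range(2):
    suffix = f"moe{dp_rank:0>2d}"  # from weight_name_suffix() in the non-hybrid branch
    print(add_variant("model_state.pdparams", suffix))
# model_state.moe00.pdparams
# model_state.moe01.pdparams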