Commit c807bd8

[fea] moe support (PaddlePaddle#8498)

Co-authored-by: kebo01 <kebo01@baidu.com>
1 parent: 82a7177

6 files changed: +192 additions, -44 deletions


docs/trainer.md

Lines changed: 4 additions & 0 deletions
@@ -705,4 +705,8 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
       Whether use flatten_param_grads method in optimizer,
       only used on NPU devices.(default:False)
 
+--use_expert_parallel
+      Whether to enable MoE (Mixture of Experts) expert parallel training.
+      (default: False)
+
 ```
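Note: for reference, a hypothetical way to switch the new flag on from Python. Only `use_expert_parallel` comes from this commit; the output directory and batch size below are placeholders.

```python
# Hypothetical usage sketch -- only `use_expert_parallel` is introduced by this commit;
# all other values are placeholders.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./moe_ckpts",         # placeholder path
    per_device_train_batch_size=1,    # placeholder
    use_expert_parallel=True,         # new flag (default: False)
)
print(training_args.use_expert_parallel)  # True
```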

paddlenlp/trainer/trainer.py

Lines changed: 69 additions & 31 deletions
@@ -143,6 +143,7 @@
 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
            )
            self.model.set_state_dict(state_dict)
        else:
-            if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
+            if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
 
                weights_file = os.path.join(
                    resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -930,22 +931,17 @@ def _inner_training_loop(
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                self.timers and self.timers("forward-backward").start()
 
-                dp_enabled = (
-                    self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
-                )
-                forbidden_no_sync = False
                # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
                # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
-                if self.args.use_hybrid_parallel:
-                    forbidden_no_sync = True
-
-                availiable_no_sync = dp_enabled and not forbidden_no_sync
-
+                availiable_no_sync = hasattr(model, "no_sync")
                is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (
+                        ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                        and args._no_sync_in_gradient_accumulation
+                    )
+                    or args.recompute
+                    or args.use_expert_parallel
+                ) and availiable_no_sync
                # sharding
                # stage1. the same as ddp
                # stage2. manualy collect gradient on dp group
@@ -965,6 +961,14 @@ def _inner_training_loop(
 
                tr_loss += tr_loss_step
 
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_expert_parallel:
+                        logger.warning("found `no sync` param when `use_expert_parallel=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
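Note: `fused_allreduce_gradients_no_sync` above drops every parameter whose `no_sync` attribute is truthy from the data-parallel allreduce. The commit does not show where that attribute gets set; below is a minimal sketch of how an MoE model might tag its expert parameters. The attribute name comes from this diff, the layers are stand-ins.

```python
# Illustrative only: tag expert parameters so the dp-group allreduce skips them.
# `no_sync` is the attribute checked by fused_allreduce_gradients_no_sync;
# the layers below are stand-ins, not part of this commit.
import paddle.nn as nn

expert_ffn = nn.Linear(64, 256)      # pretend this is one expert's projection
for p in expert_ffn.parameters():
    p.no_sync = True                 # excluded from the data-parallel allreduce

shared_proj = nn.Linear(64, 64)      # ordinary (replicated) parameters stay untagged

params = list(expert_ffn.parameters()) + list(shared_proj.parameters())
non_moe = [p for p in params if not getattr(p, "no_sync", False)]
assert len(non_moe) == len(list(shared_proj.parameters()))
```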
@@ -983,12 +987,12 @@ def _inner_training_loop(
 
                    # Case 1: Use recompute and dp / sharding stage1,
                    # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                    # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                    # Pipeline parallel mode, handle gradient reduce here to overlap
                    pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@ def _inner_training_loop(
                            self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                        if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                            fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                            fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                    self.timers and self.timers("all-reduce").stop()
                    self.timers and self.timers("optimizer-step").start()
 
@@ -1028,6 +1031,8 @@ def _inner_training_loop(
                    )
                    optimizer_was_run = True
                    if self.do_grad_scaling:
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_expert_parallel, "pipeline moe not work under fp16"
                        scale_before = paddle.assign(self.scaler._scale)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
@@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
 
        model.train()
        inputs = self._prepare_inputs(inputs)
-
        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)
 
@@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
-
        return loss.detach()
 
    def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
        # For ckpt integrity
        paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
 
+    def _filter_moe_no_sync_optimizer_params(self):
+        """
+        filter optimizer params which should not sync
+        """
+        state_dict = self.model.state_dict()
+        optimzier_state_dict = self.optimizer.state_dict()
+        filter_optimzier_state_dict = OrderedDict()
+        param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
+        filter_optimzier_state_dict["master_weights"] = OrderedDict()
+        for k, v in state_dict.items():
+            if getattr(v, "no_sync", False):
+                if v.name in param_names_in_master_weights:
+                    filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
+                        v.name
+                    ]
+                for op_k, op_v in optimzier_state_dict.items():
+                    if op_k.startswith(v.name):
+                        filter_optimzier_state_dict[op_k] = op_v
+        return filter_optimzier_state_dict
+
    def _save_checkpoint(self, model, metrics=None):
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
        self.runtime_timer.start("checkpoint saving time")
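Note: a toy illustration of what `_filter_moe_no_sync_optimizer_params` keeps -- only the optimizer entries (and master weights) that belong to `no_sync` expert parameters. The key-prefix rule is the one used above (`op_k.startswith(v.name)`); the parameter and state names are invented.

```python
# Toy stand-ins for self.optimizer.state_dict(); "linear_7.w_0" plays the role of
# an expert parameter whose tensor carries no_sync=True.
from collections import OrderedDict

expert_name = "linear_7.w_0"
optimizer_state = {
    "linear_7.w_0_moment1_0": "<tensor>",
    "linear_7.w_0_moment2_0": "<tensor>",
    "linear_3.w_0_moment1_0": "<tensor>",   # non-expert entry: dropped
    "master_weights": {"linear_7.w_0": "<fp32>", "linear_3.w_0": "<fp32>"},
}

filtered = OrderedDict(master_weights=OrderedDict())
filtered["master_weights"][expert_name] = optimizer_state["master_weights"][expert_name]
for key, value in optimizer_state.items():
    if key.startswith(expert_name):          # same rule as op_k.startswith(v.name)
        filtered[key] = value

print(list(filtered))  # ['master_weights', 'linear_7.w_0_moment1_0', 'linear_7.w_0_moment2_0']
```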
@@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None):
        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
 
        if self.args.use_hybrid_parallel:
-            if self.dp_group.rank <= 0:
+            if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
                os.makedirs(output_dir, exist_ok=True)
                logger.info("Saving optimizer files.")
                if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@ def _save_checkpoint(self, model, metrics=None):
                        safe_serialization=True,
                    )
                else:
-                    self._save_ckpt_func(
-                        self.optimizer.state_dict(),
-                        os.path.join(output_dir, optimizer_name),
-                    )
+                    if self.dp_group.rank > 0:  # this should only work for MoE saving
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(),
+                            os.path.join(output_dir, optimizer_name),
+                        )
+                    else:
+                        self._save_ckpt_func(
+                            self.optimizer.state_dict(),
+                            os.path.join(output_dir, optimizer_name),
+                        )
 
-        if self.args.should_save:
+        if self.args.should_save or self.args.use_expert_parallel:
            if not self.args.use_hybrid_parallel:
                logger.info("Saving optimizer files.")
                if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@ def _save_checkpoint(self, model, metrics=None):
                        safe_serialization=True,
                    )
                else:
-                    self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+                    if self.dp_group.rank > 0:
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
+                        )
+                    else:
+                        self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
 
            # FIXME: maybe only save one copy
            paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")
 
        if not use_unified_checkpoint:
-            if self.args.data_parallel_rank == 0:
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
                optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                path = os.path.join(checkpoint, optimizer_name)
                if os.path.isfile(path):
@@ -2476,7 +2510,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
        # broadcast optimizer state in dp group
        if self.args.local_rank != -1:
            dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if self.args.use_expert_parallel:
+            opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
+        else:
+            if not self.args.should_load_sharding_stage1_model:
+                opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
 
        if opt_state_dict is not None:
            # Load in optimizer and scheduler states

paddlenlp/trainer/training_args.py

Lines changed: 26 additions & 4 deletions
@@ -791,6 +791,10 @@ class TrainingArguments:
        default=False,
        metadata={"help": "whether to output logits in distributed status"},
    )
+    use_expert_parallel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
+    )
 
    def __post_init__(self):
        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1117,6 +1121,8 @@ def is_segment_parallel_supported():
                    order = ["dp", "sharding", "pp", "sep", "mp"]
                else:
                    order = ["dp", "sharding", "pp", "mp"]
+                if self.use_expert_parallel:
+                    order = order[1:-1] + ["dp", "mp"]
 
                if is_segment_parallel_supported():
                    hybrid_configs = {
@@ -1598,9 +1604,12 @@ def optimizer_name_suffix(self):
            if self.sharding_parallel_degree > 1:
                assert self.sharding_parallel_degree < 100, "sharding parallel degree should be less than 100."
                name.append(f"shard{self.sharding_parallel_rank:0>2d}")
-
+            if self.use_expert_parallel:
+                name.append(f"moe{self.data_parallel_rank:0>2d}")
            return "_".join(name)
        else:
+            if self.use_expert_parallel:
+                return f"moe{self.data_parallel_rank:0>2d}"
            return None
 
    @property
@@ -1613,12 +1622,16 @@ def weight_name_suffix(self):
            if self.pipeline_parallel_degree > 1:
                assert self.pipeline_parallel_degree < 100, "tensor parallel rank should be less than 100."
                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+            if self.use_expert_parallel:
+                name.append(f"moe{self.data_parallel_rank:0>2d}")
            return "_".join(name)
 
        else:
+            if self.use_expert_parallel:
+                return f"moe{self.data_parallel_rank:0>2d}"
            return None
 
-    def sharded_name_suffix(self, shard_id=None, pp_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
        if self.use_hybrid_parallel:
            name = []
            if self.tensor_parallel_degree > 1:
@@ -1636,8 +1649,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
                assert isinstance(shard_id, int)
                assert shard_id < 100, "shard_id should be less than 100."
                name.append(f"shard{shard_id:0>2d}")
+            if self.use_expert_parallel:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                assert isinstance(moe_id, int)
+                name.append(f"moe{moe_id:0>2d}")
            return "_".join(name)
        else:
+            if self.use_expert_parallel:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                return self._format_name("moe", moe_id, self.data_parallel_degree)
            return None
 
    @property
@@ -1730,9 +1752,9 @@ def should_save_model_state(self):
            return True
        elif self.use_hybrid_parallel:
            # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
+            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
        else:
-            return self.process_index == 0
+            return self.process_index == 0 or self.use_expert_parallel
 
    @property
    def _no_sync_in_gradient_accumulation(self):
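Note: to make the new checkpoint naming concrete, here is a simplified, standalone re-implementation of the suffix rule added above. In the real property a rank is only appended when its parallel degree is > 1; the ranks below are made up.

```python
# Simplified sketch of optimizer_name_suffix with use_expert_parallel enabled.
def optimizer_name_suffix(tp_rank=1, pp_rank=0, sharding_rank=None, dp_rank=1,
                          use_expert_parallel=True):
    name = []
    if tp_rank is not None:
        name.append(f"tp{tp_rank:0>2d}")
    if pp_rank is not None:
        name.append(f"pp{pp_rank:0>2d}")
    if sharding_rank is not None:
        name.append(f"shard{sharding_rank:0>2d}")
    if use_expert_parallel:
        name.append(f"moe{dp_rank:0>2d}")   # the dp rank doubles as the expert-parallel rank
    return "_".join(name)

print(optimizer_name_suffix())  # tp01_pp00_moe01
```

In trainer.py this suffix is attached to OPTIMIZER_NAME through `_add_variant`, so with expert parallelism every data-parallel rank writes its own optimizer shard rather than only dp rank 0.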

paddlenlp/trainer/utils/helper.py

Lines changed: 56 additions & 0 deletions
@@ -226,3 +226,59 @@ def broadcast_dp_optimizer(state_dict):
        state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)
 
    return state_dict
+
+
+def broadcast_moe_optimizer(state_dict):
+
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        src_rank = hcg.get_data_parallel_group_src_rank()
+        data_parallel_rank = hcg.get_data_parallel_rank()
+        # Don't broadcast optimizer for dp rank is 1.
+        if dp_group.nranks <= 1:
+            return state_dict
+    except:
+        dp_group = None
+        src_rank = 0
+        data_parallel_rank = 0
+
+    def _broadcast_moe_optimizer_state(state_dict):
+        # boardcast_keys
+        base_state_dict = {"master_weights": {}}
+        buf = [
+            {i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
+            {i: j.shape for i, j in state_dict["master_weights"].items()},
+            {"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
+        ]
+
+        dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
+        # logger.info(f"moe-optimizer-gather-keys{buf}")
+        for k, s in buf[0].items():
+            v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
+            v.name = k
+            # k = k.replace("_fp32_master_0", "")
+            dist.broadcast(v, src=src_rank, group=dp_group)
+            logger.info(f"broadcast moe optimizer {k} from {src_rank}")
+            base_state_dict[k] = v.cpu()
+        for k, s in buf[1].items():
+            v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
+            v.name = k
+            dist.broadcast(v, src=src_rank, group=dp_group)
+            logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
+            base_state_dict["master_weights"][k] = v.cpu()
+        base_state_dict.update(buf[2])
+        return base_state_dict
+
+    base_state_dict = _broadcast_moe_optimizer_state(state_dict)
+    if data_parallel_rank > 0:
+        master_weight = state_dict.pop("master_weights", {})
+        base_state_dict.update(state_dict)
+        if master_weight:
+            if "master_weights" in base_state_dict:
+                base_state_dict["master_weights"].update(master_weight)
+            else:
+                base_state_dict["master_weights"] = master_weight
+    state_dict = base_state_dict
+    del base_state_dict
+    return state_dict
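Note: the tail of `broadcast_moe_optimizer` merges two sources on every dp rank > 0: the state broadcast from dp rank 0 (the base) and the rank's own local entries for `no_sync` expert parameters, which overwrite or extend that base. A rough, self-contained illustration of the merge (all names invented):

```python
# What the merged optimizer state looks like on a dp rank > 0, schematically.
broadcast_from_rank0 = {                      # received via dist.broadcast from src_rank
    "shared.w_0_moment1_0": "rank0 tensor",
    "master_weights": {"shared.w_0": "rank0 fp32"},
    "LR_Scheduler": {"last_epoch": 3},
}
local_expert_state = {                        # this rank's own expert (no_sync) entries
    "expert_5.w_0_moment1_0": "local tensor",
    "master_weights": {"expert_5.w_0": "local fp32"},
}

merged = dict(broadcast_from_rank0)
master = local_expert_state.pop("master_weights", {})
merged.update(local_expert_state)                      # local expert entries win
merged.setdefault("master_weights", {}).update(master)

print(sorted(merged))
# ['LR_Scheduler', 'expert_5.w_0_moment1_0', 'master_weights', 'shared.w_0_moment1_0']
print(sorted(merged["master_weights"]))
# ['expert_5.w_0', 'shared.w_0']
```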
