Skip to content

Commit 2851da8

Browse files
committed
[fea] moe support
1 parent 0cd8fe7 commit 2851da8

File tree

5 files changed

+215
-22
lines changed

5 files changed

+215
-22
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
from .utils import reshard as reshard_util
144144
from .utils.helper import ( # nested_truncate,
145145
broadcast_dp_optimizer,
146+
broadcast_moe_optimizer,
146147
distributed_concat,
147148
distributed_file,
148149
distributed_isfile,
@@ -945,7 +946,8 @@ def _inner_training_loop(
945946
((step_control + 1) % args.gradient_accumulation_steps != 0)
946947
and availiable_no_sync
947948
and args._no_sync_in_gradient_accumulation
948-
) or (args.recompute and availiable_no_sync)
949+
) or (args.recompute and availiable_no_sync
950+
) or (args.use_moe and availiable_no_sync)
949951
# sharding
950952
# stage1. the same as ddp
951953
# stage2. manualy collect gradient on dp group
@@ -965,6 +967,14 @@ def _inner_training_loop(
965967

966968
tr_loss += tr_loss_step
967969

970+
def fused_allreduce_gradients_no_sync(paramlist, hcg):
971+
paramlist = list(paramlist)
972+
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
973+
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
974+
if moelist and not self.args.use_moe:
975+
logger.warning("found `no sync` param when `use_moe=False`")
976+
fused_allreduce_gradients(nonmoe_list, hcg)
977+
968978
if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
969979
# last step in epoch but step is always smaller than gradient_accumulation_steps
970980
steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +993,12 @@ def _inner_training_loop(
983993

984994
# Case 1: Use recompute and dp / sharding stage1,
985995
# manualy collect gradient for dp.
986-
if args.recompute and availiable_no_sync:
987-
fused_allreduce_gradients(list(model.parameters()), None)
996+
if (args.recompute or args.use_moe) and availiable_no_sync:
997+
fused_allreduce_gradients_no_sync(list(model.parameters()), None)
988998

989999
# Case 2: hack dp with master_grad
990-
if dp_master_grad and not (args.recompute and availiable_no_sync):
991-
fused_allreduce_gradients(list(model.parameters()), None)
1000+
elif dp_master_grad:
1001+
fused_allreduce_gradients_no_sync(list(model.parameters()), None)
9921002

9931003
# Pipeline parallel mode, handle gradient reduce here to overlap
9941004
pipeline_parallel_config = (
@@ -1007,8 +1017,7 @@ def _inner_training_loop(
10071017
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
10081018

10091019
if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
1010-
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
1011-
1020+
fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
10121021
self.timers and self.timers("all-reduce").stop()
10131022
self.timers and self.timers("optimizer-step").start()
10141023

@@ -1028,7 +1037,9 @@ def _inner_training_loop(
10281037
)
10291038
optimizer_was_run = True
10301039
if self.do_grad_scaling:
1031-
scale_before = paddle.assign(self.scaler._scale)
1040+
if args.pipeline_parallel_degree > 1:
1041+
assert not self.args.use_moe, "pipline moe not work under fp16"
1042+
scale_before = self.scaler._scale.numpy()
10321043
self.scaler.step(self.optimizer)
10331044
self.scaler.update()
10341045
scale_after = self.scaler._scale
@@ -2042,7 +2053,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
20422053

20432054
model.train()
20442055
inputs = self._prepare_inputs(inputs)
2045-
2056+
self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()
20462057
with self.autocast_smart_context_manager():
20472058
loss = self.compute_loss(model, inputs)
20482059

@@ -2053,7 +2064,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
20532064
self.scaler.scale(loss).backward()
20542065
else:
20552066
loss.backward()
2056-
2067+
self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()
20572068
return loss.detach()
20582069

20592070
def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2154,19 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
21432154
# For ckpt integrity
21442155
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
21452156

2157+
def _save_moe_weights(
2158+
self,
2159+
output_dir: Optional[str] = None,
2160+
merge_tensor_parallel: Optional[bool] = False,):
2161+
# save moe optimizer and model state # TODO 默认为冗余存储
2162+
2163+
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
2164+
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
2165+
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
2166+
paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
2167+
with open(saved_signal_path, mode="w+") as f:
2168+
f.write("1")
2169+
21462170
def _save_checkpoint(self, model, metrics=None):
21472171
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
21482172
self.runtime_timer.start("checkpoint saving time")
@@ -2245,6 +2269,8 @@ def _save_checkpoint(self, model, metrics=None):
22452269
os.makedirs(output_dir, exist_ok=True)
22462270
paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
22472271

2272+
if self.args.use_moe and self.args.data_parallel_rank > 0:
2273+
self._save_moe_weights(output_dir)
22482274
# Maybe delete some older checkpoints.
22492275
# For hybrid parallel training, the checkpoint files maybe on different node.
22502276
need_to_rotate_checkpoints = False
@@ -2476,7 +2502,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
24762502
# broadcast optimizer state in dp group
24772503
if self.args.local_rank != -1:
24782504
dist.barrier()
2479-
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
2505+
if not self.args.use_moe:
2506+
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
2507+
else:
2508+
opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
24802509

24812510
if opt_state_dict is not None:
24822511
# Load in optimizer and scheduler states

paddlenlp/trainer/training_args.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,10 @@ class TrainingArguments:
803803
default=False,
804804
metadata={"help": "whether to run distributed training in auto parallel mode"},
805805
)
806+
use_moe: Optional[bool] = field(
807+
default=False,
808+
metadata={"help": "开启moe训练"},
809+
)
806810

807811
def __post_init__(self):
808812
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
11491153
order = ["dp", "sharding", "pp", "sep", "mp"]
11501154
else:
11511155
order = ["dp", "sharding", "pp", "mp"]
1156+
if self.use_moe:
1157+
order = order[1: -1] + ["dp", "mp"]
11521158

11531159
if is_segment_parallel_supported():
11541160
hybrid_configs = {
@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
16401646
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
16411647
if self.sharding_parallel_degree > 1:
16421648
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
1649+
if self.use_moe:
1650+
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
16431651
return "_".join(name)
16441652
else:
1653+
if self.use_moe:
1654+
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
16451655
return None
16461656

16471657
@property
@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
16521662
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
16531663
if self.pipeline_parallel_degree > 1:
16541664
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
1665+
if self.use_moe:
1666+
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
16551667
return "_".join(name)
16561668

16571669
else:
1670+
if self.use_moe:
1671+
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
16581672
return None
16591673

1660-
def sharded_name_suffix(self, shard_id=None, pp_id=None):
1674+
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
16611675
if self.use_hybrid_parallel:
16621676
name = []
16631677
if self.tensor_parallel_degree > 1:
@@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
16721686
shard_id = self.sharding_parallel_rank
16731687
assert isinstance(shard_id, int)
16741688
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
1689+
if self.use_moe:
1690+
if moe_id is None:
1691+
moe_id = self.data_parallel_rank
1692+
assert isinstance(moe_id, int)
1693+
name.append(self._format_name("moe", moe_id, self.data_parallel_degree))
16751694
return "_".join(name)
16761695
else:
1696+
if self.use_moe:
1697+
if moe_id is None:
1698+
moe_id = self.data_parallel_rank
1699+
return self._format_name("moe", moe_id, self.data_parallel_degree)
16771700
return None
16781701

16791702
@property

paddlenlp/trainer/utils/helper.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import os
2020
from typing import Any, Optional
21-
21+
import copy
2222
import numpy as np
2323
import paddle
2424
import paddle.distributed as dist
@@ -226,3 +226,108 @@ def broadcast_dp_optimizer(state_dict):
226226
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)
227227

228228
return state_dict
229+
230+
# def broadcast_moe_optimizer(state_dict):
231+
# if paddle.distributed.get_world_size() <= 1:
232+
# return state_dict
233+
234+
# logger.info("Start broadcast optimizer in MoE(data) parallel group.")
235+
# try:
236+
# hcg = fleet.get_hybrid_communicate_group()
237+
# dp_group = hcg.get_data_parallel_group()
238+
# src_rank = hcg.get_data_parallel_group_src_rank()
239+
# process_rank = paddle.distributed.get_rank()
240+
# # Don't broadcast optimizer for dp rank is 1.
241+
# if dp_group.nranks <= 1:
242+
# return state_dict
243+
# except:
244+
# dp_group = None
245+
# src_rank = 0
246+
# process_rank = paddle.distributed.get_rank()
247+
248+
# if process_rank == src_rank:
249+
# if state_dict is None:
250+
# logger.warning(
251+
# f"Your local rank {paddle.distributed.get_rank()} must have a state_dict. dp_rank:{process_rank}, src_rank:{src_rank}"
252+
# )
253+
# fake_state_dict = [nested_reduce_tensor(state_dict)]
254+
# else:
255+
# fake_state_dict = [None]
256+
257+
# paddle.distributed.broadcast_object_list(
258+
# fake_state_dict,
259+
# src=src_rank,
260+
# group=dp_group,
261+
# )
262+
# fake_state_dict = fake_state_dict[0]
263+
# if process_rank != src_rank:
264+
# sync_state_dict = nested_empty_tensor(fake_state_dict)
265+
# else:
266+
# sync_state_dict = state_dict
267+
# logger.info(f"SYNC-state-dict--{sync_state_dict.keys()}")
268+
# sync_state_dict = nested_broadcast_tensor(sync_state_dict, src=src_rank, group=dp_group)
269+
# if process_rank != src_rank:
270+
# master_weights = state_dict.pop('master_weights', {})
271+
# sync_state_dict['master_weights'].update(master_weights)
272+
# sync_state_dict.update(state_dict)
273+
# state_dict = sync_state_dict
274+
# logger.info("broadcast_moe_optimizer done")
275+
# return state_dict
276+
277+
278+
def broadcast_moe_optimizer(state_dict):
279+
280+
try:
281+
hcg = fleet.get_hybrid_communicate_group()
282+
dp_group = hcg.get_data_parallel_group()
283+
src_rank = hcg.get_data_parallel_group_src_rank()
284+
process_rank = paddle.distributed.get_rank()
285+
data_parallel_rank = hcg.get_data_parallel_rank()
286+
# Don't broadcast optimizer for dp rank is 1.
287+
if dp_group.nranks <= 1:
288+
return state_dict
289+
except:
290+
dp_group = None
291+
src_rank = 0
292+
data_parallel_rank = 0
293+
process_rank = paddle.distributed.get_rank()
294+
295+
def _broadcast_moe_optimizer_state(state_dict):
296+
# boardcast_keys
297+
base_state_dict = {"master_weights": {}}
298+
buf = [
299+
{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
300+
{i: j.shape for i, j in state_dict["master_weights"].items()},
301+
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
302+
]
303+
304+
dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
305+
# logger.info(f"moe-optimizer-gather-keys{buf}")
306+
for k, s in buf[0].items():
307+
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
308+
v.name = k
309+
# k = k.replace("_fp32_master_0", "")
310+
dist.broadcast(v, src=src_rank, group=dp_group)
311+
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
312+
base_state_dict[k] = v.cpu()
313+
for k, s in buf[1].items():
314+
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
315+
v.name = k
316+
dist.broadcast(v, src=src_rank, group=dp_group)
317+
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
318+
base_state_dict["master_weights"][k] = v.cpu()
319+
base_state_dict.update(buf[2])
320+
return base_state_dict
321+
322+
base_state_dict = _broadcast_moe_optimizer_state(state_dict)
323+
if data_parallel_rank > 0:
324+
master_weight = state_dict.pop("master_weights", {})
325+
base_state_dict.update(state_dict)
326+
if master_weight:
327+
if "master_weights" in base_state_dict:
328+
base_state_dict["master_weights"].update(master_weight)
329+
else:
330+
base_state_dict["master_weights"] = master_weight
331+
state_dict = base_state_dict
332+
del base_state_dict
333+
return state_dict

paddlenlp/trainer/utils/reshard/common.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,16 @@ def _opt_name_to_tname(tensor_names, opt_names):
266266
all_names.extend(opt_names)
267267
all_names.sort()
268268
pre_t_name = ""
269+
suffix = [
270+
"_fp32_master_0_beta1_pow_acc_0",
271+
"_fp32_master_0_beta2_pow_acc_0",
272+
"_fp32_master_0_moment1_0",
273+
"_fp32_master_0_moment2_0",
274+
"_beta1_pow_acc_0",
275+
"_beta2_pow_acc_0",
276+
"_moment1_0",
277+
"_moment2_0",
278+
]
269279
opt_to_t = {}
270280
for n in all_names:
271281
if n in tensor_names:
@@ -274,6 +284,24 @@ def _opt_name_to_tname(tensor_names, opt_names):
274284
else:
275285
assert pre_t_name
276286
opt_to_t[n] = pre_t_name
287+
288+
for t in opt_names:
289+
_find = False
290+
for s in suffix:
291+
if t.endswith(s):
292+
logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
293+
opt_to_t[t] = t[:-len(s)]
294+
_find = True
295+
break
296+
assert _find
297+
# opt_to_t = {}
298+
# for n in all_names:
299+
# if n in tensor_names:
300+
# # we get a param
301+
# pre_t_name = n
302+
# else:
303+
# assert pre_t_name
304+
# opt_to_t[n] = pre_t_name
277305
return opt_to_t
278306

279307
if structure_name_mapping is not None:
@@ -291,7 +319,7 @@ def _opt_name_to_tname(tensor_names, opt_names):
291319
(self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights)
292320
for k in list(model_weights_tmp.keys()):
293321
t_name = structure_name_mapping[k]
294-
self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu()
322+
self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu()
295323
del model_weights_tmp[k]
296324

297325
# opt

0 commit comments

Comments
 (0)