
Commit 939c0d2

[fea] moe support
1 parent 0cd8fe7 commit 939c0d2

File tree: 5 files changed, +167 −33 lines

paddlenlp/trainer/trainer.py

Lines changed: 42 additions & 23 deletions
@@ -143,6 +143,7 @@
 from .utils import reshard as reshard_util
 from .utils.helper import ( # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -930,22 +931,14 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 self.timers and self.timers("forward-backward").start()
 
-                dp_enabled = (
-                    self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
-                )
-                forbidden_no_sync = False
                 # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
                 # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
-                if self.args.use_hybrid_parallel:
-                    forbidden_no_sync = True
-
-                availiable_no_sync = dp_enabled and not forbidden_no_sync
-
+                availiable_no_sync = hasattr(model, "no_sync")
                 is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (((step + 1) % args.gradient_accumulation_steps != 0) and args._no_sync_in_gradient_accumulation)
+                    or args.recompute
+                    or args.use_moe
+                ) and availiable_no_sync
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
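
The rewritten `is_no_sync` condition is the core behavioural change here: once `use_moe` is on (and the wrapped model exposes `no_sync`), automatic gradient synchronization is suppressed on every step, and the manual helper introduced in the next hunk decides what actually gets all-reduced. A minimal, stand-alone restatement of that predicate (plain Python, not part of the diff):

def is_no_sync(step, accum_steps, no_sync_in_accum, recompute, use_moe, availiable_no_sync):
    # Mirrors the expression added above: MoE (or recompute) forces no_sync
    # whenever the model supports it, independent of the accumulation boundary.
    return (
        (((step + 1) % accum_steps != 0) and no_sync_in_accum)
        or recompute
        or use_moe
    ) and availiable_no_sync

# With MoE enabled, even the final accumulation step stays in no_sync mode:
assert is_no_sync(step=3, accum_steps=4, no_sync_in_accum=True,
                  recompute=False, use_moe=True, availiable_no_sync=True)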
@@ -965,6 +958,14 @@ def _inner_training_loop(
 
                 tr_loss += tr_loss_step
 
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
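
`fused_allreduce_gradients_no_sync` keys off a `no_sync` attribute on individual parameters: anything tagged is left out of the data-parallel all-reduce, on the assumption that expert gradients are local to each dp rank or reduced elsewhere. A hedged sketch of how a MoE layer might tag its expert weights so this filter picks them up (`expert` and `shared` below are illustrative layers, not names from this commit):

import paddle.nn as nn

expert = nn.Linear(1024, 4096)   # stands in for one expert's FFN weight
shared = nn.Linear(1024, 1024)   # a normal, data-parallel-replicated weight

for p in expert.parameters():
    p.no_sync = True             # read back via getattr(p, "no_sync", False)

params = list(expert.parameters()) + list(shared.parameters())
dense_only = [p for p in params if not getattr(p, "no_sync", False)]
assert len(dense_only) == len(list(shared.parameters()))  # experts filtered out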
@@ -983,12 +984,12 @@ def _inner_training_loop(
 
                     # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1008,7 @@ def _inner_training_loop(
                             self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                         if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                            fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                            fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
 
@@ -1028,7 +1028,9 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline moe does not work under fp16"
+                        scale_before = self.scaler._scale.numpy()
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                         scale_after = self.scaler._scale
@@ -2042,7 +2044,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
 
         model.train()
         inputs = self._prepare_inputs(inputs)
-
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
@@ -2053,7 +2054,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
         return loss.detach()
 
     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2143,20 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         # For ckpt integrity
         paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
 
+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        # save moe optimizer and model state  # TODO: redundant storage by default
+
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
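
`_save_moe_weights` writes each dp rank's optimizer state under a rank-specific suffix and then drops a `saved_signal_{rank}` marker file. The diff itself does not show who consumes these markers; a hedged sketch of one possible consumer that waits for every rank to finish before treating the checkpoint as complete (the function and its parameters are assumptions, not part of this commit):

import os
import time

def wait_for_saved_signals(output_dir, world_size, timeout=1800, poll=5):
    """Block until saved_signal_0 .. saved_signal_{world_size-1} all exist."""
    expected = [os.path.join(output_dir, f"saved_signal_{r}") for r in range(world_size)]
    deadline = time.time() + timeout
    while time.time() < deadline:
        if all(os.path.exists(p) for p in expected):
            return True
        time.sleep(poll)
    return False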
@@ -2245,6 +2259,8 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
@@ -2476,7 +2492,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # broadcast optimizer state in dp group
         if self.args.local_rank != -1:
             dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if not self.args.use_moe:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        else:
+            opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
 
         if opt_state_dict is not None:
             # Load in optimizer and scheduler states

paddlenlp/trainer/training_args.py

Lines changed: 24 additions & 1 deletion
@@ -803,6 +803,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    use_moe: Optional[bool] = field(
+        default=False,
+        metadata={"help": "enable MoE training"},
+    )
 
     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
                order = ["dp", "sharding", "pp", "sep", "mp"]
            else:
                order = ["dp", "sharding", "pp", "mp"]
+            if self.use_moe:
+                order = order[1:-1] + ["dp", "mp"]
 
            if is_segment_parallel_supported():
                hybrid_configs = {
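
The slice-and-append above moves "dp" from the outermost axis to just before "mp", so with MoE enabled the data-parallel (expert) group becomes the innermost axis other than tensor parallelism. A quick check of the expression with the two possible base orders:

for base in (["dp", "sharding", "pp", "mp"], ["dp", "sharding", "pp", "sep", "mp"]):
    reordered = base[1:-1] + ["dp", "mp"]
    print(base, "->", reordered)
# ['dp', 'sharding', 'pp', 'mp'] -> ['sharding', 'pp', 'dp', 'mp']
# ['dp', 'sharding', 'pp', 'sep', 'mp'] -> ['sharding', 'pp', 'sep', 'dp', 'mp']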
@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
            if self.sharding_parallel_degree > 1:
                name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
            return "_".join(name)
        else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
            return None
 
    @property
@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
            if self.pipeline_parallel_degree > 1:
                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
            return "_".join(name)
 
        else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
            return None
 
-    def sharded_name_suffix(self, shard_id=None, pp_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
        if self.use_hybrid_parallel:
            name = []
            if self.tensor_parallel_degree > 1:
@@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
                    shard_id = self.sharding_parallel_rank
                assert isinstance(shard_id, int)
                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                assert isinstance(moe_id, int)
+                name.append(self._format_name("moe", moe_id, self.data_parallel_degree))
            return "_".join(name)
        else:
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                return self._format_name("moe", moe_id, self.data_parallel_degree)
            return None
 
    @property
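
All three suffix properties now append a "moe" component derived from the data-parallel rank, so each dp rank writes distinct optimizer and weight files instead of sharing one copy. A hedged illustration of the resulting names; `format_name` below is a stand-in for `TrainingArguments._format_name`, and the zero-padded width is an assumption rather than something shown in this diff:

def format_name(prefix, rank, degree):
    # assumed formatting, e.g. "moe03" for rank 3
    return f"{prefix}{rank:02d}"

# e.g. tp rank 1, pp rank 0, moe (dp) rank 3 with use_moe=True:
suffix = "_".join([format_name("tp", 1, 8), format_name("pp", 0, 2), format_name("moe", 3, 4)])
print(suffix)  # tp01_pp00_moe03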

paddlenlp/trainer/utils/helper.py

Lines changed: 56 additions & 0 deletions
@@ -226,3 +226,59 @@ def broadcast_dp_optimizer(state_dict):
         state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)
 
     return state_dict
+
+
+def broadcast_moe_optimizer(state_dict):
+
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        src_rank = hcg.get_data_parallel_group_src_rank()
+        data_parallel_rank = hcg.get_data_parallel_rank()
+        # Don't broadcast the optimizer state when the dp degree is 1.
+        if dp_group.nranks <= 1:
+            return state_dict
+    except:
+        dp_group = None
+        src_rank = 0
+        data_parallel_rank = 0
+
+    def _broadcast_moe_optimizer_state(state_dict):
+        # broadcast keys
+        base_state_dict = {"master_weights": {}}
+        buf = [
+            {i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
+            {i: j.shape for i, j in state_dict["master_weights"].items()},
+            {"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
+        ]
+
+        dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
+        # logger.info(f"moe-optimizer-gather-keys{buf}")
+        for k, s in buf[0].items():
+            v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
+            v.name = k
+            # k = k.replace("_fp32_master_0", "")
+            dist.broadcast(v, src=src_rank, group=dp_group)
+            logger.info(f"broadcast moe optimizer {k} from {src_rank}")
+            base_state_dict[k] = v.cpu()
+        for k, s in buf[1].items():
+            v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
+            v.name = k
+            dist.broadcast(v, src=src_rank, group=dp_group)
+            logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
+            base_state_dict["master_weights"][k] = v.cpu()
+        base_state_dict.update(buf[2])
+        return base_state_dict
+
+    base_state_dict = _broadcast_moe_optimizer_state(state_dict)
+    if data_parallel_rank > 0:
+        master_weight = state_dict.pop("master_weights", {})
+        base_state_dict.update(state_dict)
+        if master_weight:
+            if "master_weights" in base_state_dict:
+                base_state_dict["master_weights"].update(master_weight)
+            else:
+                base_state_dict["master_weights"] = master_weight
+        state_dict = base_state_dict
+        del base_state_dict
+    return state_dict
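
The key difference from `broadcast_dp_optimizer` is the final merge: every rank first receives the src rank's states (with zero placeholders for keys it does not own), and then dp ranks > 0 overlay whatever they held locally, so expert-specific optimizer states survive the broadcast. A toy, single-process sketch of just that merge step (the key names are illustrative only):

# what arrived from the src rank via broadcast
base = {
    "shared_w_moment1_0": "from-src",
    "expert_w_moment1_0": "zeros-placeholder",
    "master_weights": {"shared_w": "from-src"},
}
# what this dp rank > 0 already held for its own experts
local = {
    "expert_w_moment1_0": "local-expert-state",
    "master_weights": {"expert_w": "local-expert-state"},
}

master = local.pop("master_weights", {})
base.update(local)                      # local expert entries win
base["master_weights"].update(master)
assert base["expert_w_moment1_0"] == "local-expert-state"
assert base["shared_w_moment1_0"] == "from-src"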

paddlenlp/trainer/utils/reshard/common.py

Lines changed: 29 additions & 1 deletion
@@ -266,6 +266,16 @@ def _opt_name_to_tname(tensor_names, opt_names):
            all_names.extend(opt_names)
            all_names.sort()
            pre_t_name = ""
+            suffix = [
+                "_fp32_master_0_beta1_pow_acc_0",
+                "_fp32_master_0_beta2_pow_acc_0",
+                "_fp32_master_0_moment1_0",
+                "_fp32_master_0_moment2_0",
+                "_beta1_pow_acc_0",
+                "_beta2_pow_acc_0",
+                "_moment1_0",
+                "_moment2_0",
+            ]
            opt_to_t = {}
            for n in all_names:
                if n in tensor_names:
@@ -274,6 +284,24 @@ def _opt_name_to_tname(tensor_names, opt_names):
                else:
                    assert pre_t_name
                    opt_to_t[n] = pre_t_name
+
+            for t in opt_names:
+                _find = False
+                for s in suffix:
+                    if t.endswith(s):
+                        logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
+                        opt_to_t[t] = t[: -len(s)]
+                        _find = True
+                        break
+                assert _find
+            # opt_to_t = {}
+            # for n in all_names:
+            #     if n in tensor_names:
+            #         # we get a param
+            #         pre_t_name = n
+            #     else:
+            #         assert pre_t_name
+            #         opt_to_t[n] = pre_t_name
            return opt_to_t
 
        if structure_name_mapping is not None:
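
With the suffix table, an optimizer-state name can be mapped back to its owning parameter by string stripping instead of the sorted-neighbour heuristic above, which breaks when a rank holds optimizer states for parameters (e.g. MoE experts) that are absent from `tensor_names`. For example, with an illustrative Paddle-style parameter name:

opt_name = "linear_0.w_0_fp32_master_0_moment1_0"
suffix = "_fp32_master_0_moment1_0"
assert opt_name.endswith(suffix)
assert opt_name[: -len(suffix)] == "linear_0.w_0"   # recovered tensor name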
@@ -291,7 +319,7 @@ def _opt_name_to_tname(tensor_names, opt_names):
            (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights)
            for k in list(model_weights_tmp.keys()):
                t_name = structure_name_mapping[k]
-                self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu()
+                self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu()
                del model_weights_tmp[k]
 
        # opt

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 16 additions & 8 deletions
@@ -67,11 +67,14 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
    if reshard_util.get_sharding_strategy(optimizer) == reshard_util.SHARDING_STRATEGY_V1:
        optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer)
        for (k, v) in state_dict.items():
-            assert v.name in optimizer._param2rank
-            sharded_rank = optimizer._param2rank[v.name]
-            if sharded_rank != sharding_rank:
-                continue
-            filtered_state_dict[k] = v
+            if v.name in optimizer._param2rank:
+                sharded_rank = optimizer._param2rank[v.name]
+                if sharded_rank != sharding_rank:
+                    continue
+                filtered_state_dict[k] = v
+            else:
+                if sharding_rank == 0:
+                    filtered_state_dict[k] = v
    else:
        optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2)
        parameters = optimizer._parameter_list
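
The relaxed filter stops asserting that every saved parameter is known to the sharding optimizer: parameters missing from `_param2rank` (as MoE expert weights can be) are kept only on sharding rank 0, so each of them is still written exactly once per sharding group. A toy restatement of that rule, keyed by name for brevity rather than by the parameter object used in the real code:

def filter_by_rank(state_dict, param2rank, sharding_rank):
    out = {}
    for name, value in state_dict.items():
        if name in param2rank:
            if param2rank[name] == sharding_rank:
                out[name] = value
        elif sharding_rank == 0:        # unknown (e.g. expert) params -> rank 0 only
            out[name] = value
    return out

state = {"dense_w": "dense", "expert_w": "expert"}
ranks = {"dense_w": 1}                   # hypothetical _param2rank mapping
assert filter_by_rank(state, ranks, 0) == {"expert_w": "expert"}
assert filter_by_rank(state, ranks, 1) == {"dense_w": "dense"}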
@@ -352,7 +355,7 @@ def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=
            )
            logger.info(
                "param_names_in_master_weights len:{}, bf16 state_dict len:{}, :{}".format(
-                    len(param_names_in_master_weights), len(state_dict), state_dict
+                    len(param_names_in_master_weights), len(state_dict), state_dict.keys()
                )
            )
            return state_dict, config_to_save, weight_name_suffix
@@ -444,12 +447,17 @@ def filter_func(name):
 
        master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group)
        model_state_dict = self.model.state_dict()
+        logger.info(f"state-dict-keys: {state_dict.keys()}, nums: {len(state_dict.keys())}")
        logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict)))
        for key, param in model_state_dict.items():
            if param.name in master_weights:
                assert param.shape == master_weights[param.name].shape
-                paddle.assign(master_weights[param.name].cuda(), model_state_dict[key])
-
+                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+            elif key in state_dict:
+                logger.info(f"key: {key} is in state_dict, but not in master_weights")
+                paddle.assign(state_dict[key], model_state_dict[key])
+            else:
+                logger.info(f"key: {key} is not in state_dict and master_weights")
        logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict)))
        state_dict.update(model_state_dict)
        return state_dict
