@@ -945,7 +945,8 @@ def _inner_training_loop(
                     ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                ) or (args.recompute and availiable_no_sync
+                ) or (args.use_moe and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manually collect gradient on dp group
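Not part of the diff: a minimal sketch of the pattern this flag feeds into, assuming the surrounding loop chooses between `model.no_sync()` and a null context based on it; `accumulation_step` and `is_no_sync` are illustrative names, not identifiers from the PR.

import contextlib

def accumulation_step(model, inputs, is_no_sync):
    # When sync is suppressed (mid-accumulation, recompute, or MoE), the
    # automatic data-parallel allreduce is skipped and gradients stay local;
    # they are fused-allreduced manually at the accumulation boundary instead.
    ctx = model.no_sync() if is_no_sync else contextlib.nullcontext()
    with ctx:
        loss = model(**inputs)
        loss.backward()
    return loss.detach()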
@@ -965,6 +966,14 @@ def _inner_training_loop(
 
                 tr_loss += tr_loss_step
 
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
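The helper added above boils down to a partition of the parameter list: anything carrying a truthy `no_sync` attribute (expert parameters that live only on their own data-parallel rank) is kept out of the dense allreduce, and a warning is logged if such parameters appear while `use_moe` is off. A standalone sketch of that partition step; `split_params_for_allreduce` is a hypothetical name, not code from the PR.

def split_params_for_allreduce(params):
    # MoE expert parameters are tagged with `no_sync=True` and must not be
    # averaged across data-parallel ranks; only the dense parameters are
    # handed to fused_allreduce_gradients.
    params = list(params)
    dense = [p for p in params if not getattr(p, "no_sync", False)]
    expert = [p for p in params if getattr(p, "no_sync", False)]
    return dense, expert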
@@ -983,12 +992,12 @@ def _inner_training_loop(
 
                     # Case 1: Use recompute and dp / sharding stage1,
                     # manually collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1016,7 @@ def _inner_training_loop(
                         self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                     if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                        fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                        fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                 self.timers and self.timers("all-reduce").stop()
                 self.timers and self.timers("optimizer-step").start()
 
@@ -1028,7 +1036,9 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline moe does not work under fp16"
+                        scale_before = self.scaler._scale.numpy()
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                         scale_after = self.scaler._scale
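Snapshotting `scale_before` as a host-side numpy value fits the usual AMP bookkeeping: `GradScaler.update()` shrinks the loss scale and skips the optimizer step when inf/nan gradients are detected, so comparing the scale before and after stepping shows whether the step actually ran. A hedged sketch of that check under those assumptions; `step_with_scaler` is illustrative, not the trainer's code.

def step_with_scaler(scaler, optimizer):
    # If update() shrank the scale, an inf/nan was found and the optimizer
    # step was skipped; callers can use the return value to decide whether
    # to advance the LR scheduler.
    scale_before = scaler._scale.numpy()
    scaler.step(optimizer)
    scaler.update()
    scale_after = scaler._scale.numpy()
    return bool(scale_after >= scale_before)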
@@ -2042,7 +2052,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
 
         model.train()
         inputs = self._prepare_inputs(inputs)
-
+        self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
@@ -2053,7 +2063,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
+        self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()
         return loss.detach()
 
     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
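Both timer lines rely on the same guard: `self.timers` is falsy when profiling is disabled, so `timers and timers(name).start()` is a no-op in that case; `_cur_acc_step` is assumed here to be the index of the current gradient-accumulation micro-step. A small sketch of the guard pattern with illustrative names, not part of the diff:

def timed_call(timers, name, fn, *args, **kwargs):
    # `timers` may be None when profiling is off; the truthiness check
    # short-circuits so start()/stop() only run when a timer object exists.
    timers and timers(name).start()
    result = fn(*args, **kwargs)
    timers and timers(name).stop()
    return result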
@@ -2142,6 +2152,18 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         if self.args.should_save_model_state and self.args.should_save:
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,):
+        # save moe optimizer and model state  # TODO: redundant storage by default
+
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")
 
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
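`_save_moe_weights` writes a per-rank `saved_signal_<rank>` marker once the expert optimizer state is on disk. A hypothetical consumer (not part of the PR) could poll for those markers to know when every rank has finished writing before merging or uploading the checkpoint:

import glob
import os
import time

def wait_for_moe_save(output_dir, expected_ranks, poll_seconds=1.0):
    # Block until every rank has dropped its saved_signal_<rank> marker file.
    pattern = os.path.join(output_dir, "saved_signal_*")
    while len(glob.glob(pattern)) < expected_ranks:
        time.sleep(poll_seconds)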
@@ -2245,6 +2267,8 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files may be on different nodes.
         need_to_rotate_checkpoints = False