from .utils import reshard as reshard_util
from .utils.helper import (  # nested_truncate,
    broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
    distributed_concat,
    distributed_file,
    distributed_isfile,
@@ -930,22 +931,14 @@ def _inner_training_loop(
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                self.timers and self.timers("forward-backward").start()

-                dp_enabled = (
-                    self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
-                )
-                forbidden_no_sync = False
                # stage2 and stage3 should not no_sync, because there is no DDP wrapper and no_sync API
                # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
-                if self.args.use_hybrid_parallel:
-                    forbidden_no_sync = True
-
-                availiable_no_sync = dp_enabled and not forbidden_no_sync
-
+                availiable_no_sync = hasattr(model, "no_sync")
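+                # NOTE: with use_moe (or recompute), DDP's automatic gradient allreduce is skipped
+                # on every micro-step; non-expert gradients are collected manually further below.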
                is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (((step + 1) % args.gradient_accumulation_steps != 0) and args._no_sync_in_gradient_accumulation)
+                    or args.recompute
+                    or args.use_moe
+                ) and availiable_no_sync
                # sharding
                # stage1. the same as ddp
                # stage2. manually collect gradient on dp group
@@ -965,6 +958,14 @@ def _inner_training_loop(

                tr_loss += tr_loss_step

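+                # NOTE: parameters tagged with a `no_sync` attribute (the MoE expert weights) are
+                # excluded; only the remaining shared parameters join the fused dp allreduce.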
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +984,12 @@ def _inner_training_loop(

                    # Case 1: Use recompute and dp / sharding stage1,
                    # manually collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                    # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                    # Pipeline parallel mode, handle gradient reduce here to overlap
                    pipeline_parallel_config = (
@@ -1007,8 +1008,7 @@ def _inner_training_loop(
                                self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                            if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                                fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                                fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                    self.timers and self.timers("all-reduce").stop()
                    self.timers and self.timers("optimizer-step").start()

@@ -1028,7 +1028,9 @@ def _inner_training_loop(
                    )
                    optimizer_was_run = True
                    if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline moe does not work under fp16"
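+                        # NOTE: keep the pre-step loss scale so it can be compared with scale_after below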
+                        scale_before = self.scaler._scale.numpy()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler._scale
@@ -2042,7 +2044,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

        model.train()
        inputs = self._prepare_inputs(inputs)
-
        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

@@ -2053,7 +2054,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
-
        return loss.detach()

    def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2143,20 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
            # For ckpt integrity
            paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        # save moe optimizer and model state  # TODO: redundant storage by default
+
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
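+        # NOTE: the per-rank `saved_signal_*` file presumably serves as a completion marker
+        # indicating that this rank has finished writing its MoE state.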
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")
+
    def _save_checkpoint(self, model, metrics=None):
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
        self.runtime_timer.start("checkpoint saving time")
@@ -2245,6 +2259,8 @@ def _save_checkpoint(self, model, metrics=None):
            os.makedirs(output_dir, exist_ok=True)
            paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

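+        # NOTE: with expert parallelism, dp ranks other than 0 hold expert weights and optimizer
+        # state that rank 0 does not have, hence the extra per-rank save.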
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
        # Maybe delete some older checkpoints.
        # For hybrid parallel training, the checkpoint files may be on different nodes.
        need_to_rotate_checkpoints = False
@@ -2476,7 +2492,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
        # broadcast optimizer state in dp group
        if self.args.local_rank != -1:
            dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
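+        # NOTE: broadcast_moe_optimizer presumably broadcasts only the shared (non-expert)
+        # optimizer state across the dp group, leaving per-expert state local to each rank.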
+        if not self.args.use_moe:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        else:
+            opt_state_dict = broadcast_moe_optimizer(opt_state_dict)

        if opt_state_dict is not None:
            # Load in optimizer and scheduler states