 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             )
             self.model.set_state_dict(state_dict)
         else:
-            if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
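+            # with expert parallel, each rank holds distinct expert weights, so every rank loads the checkpoint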
+            if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

                 weights_file = os.path.join(
                     resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -930,22 +931,17 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 self.timers and self.timers("forward-backward").start()

-                dp_enabled = (
-                    self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
-                )
-                forbidden_no_sync = False
                 # stage2 and stage3 should not no_sync, because there is no DDP wrapper and no_sync API
                 # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
-                if self.args.use_hybrid_parallel:
-                    forbidden_no_sync = True
-
-                availiable_no_sync = dp_enabled and not forbidden_no_sync
-
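+                # no_sync is only usable when the (possibly wrapped) model exposes a no_sync() context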
+                availiable_no_sync = hasattr(model, "no_sync")
                 is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (
+                        ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                        and args._no_sync_in_gradient_accumulation
+                    )
+                    or args.recompute
+                    or args.use_expert_parallel
+                ) and availiable_no_sync
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manually collect gradient on dp group
@@ -965,6 +961,14 @@ def _inner_training_loop(

                 tr_loss += tr_loss_step

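+                # Gather only the shared (non-MoE) gradients for the dp all-reduce; parameters
+                # tagged `no_sync` (MoE expert parameters) keep their local gradients.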
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_expert_parallel:
+                        logger.warning("found `no_sync` param when `use_expert_parallel=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +987,12 @@ def _inner_training_loop(

                     # Case 1: Use recompute and dp / sharding stage1,
                     # manually collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
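+                    # expert parallel likewise needs this manual all-reduce of the shared gradients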
+                    if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@ def _inner_training_loop(
                             self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                         if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                            fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                            fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()

@@ -1028,6 +1031,8 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
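+                        # expert parallel is not yet supported with fp16 grad scaling under pipeline parallel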
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_expert_parallel, "pipeline MoE does not work under fp16"
                         scale_before = paddle.assign(self.scaler._scale)
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
@@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
         model.train()
         inputs = self._prepare_inputs(inputs)
-
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)

@@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
         return loss.detach()

     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

+    def _filter_moe_no_sync_optimizer_params(self):
+        """
+        Collect optimizer states of parameters flagged `no_sync` (MoE experts); each dp rank saves these itself.
+        """
+        state_dict = self.model.state_dict()
+        optimzier_state_dict = self.optimizer.state_dict()
+        filter_optimzier_state_dict = OrderedDict()
+        param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
+        filter_optimzier_state_dict["master_weights"] = OrderedDict()
+        for k, v in state_dict.items():
+            if getattr(v, "no_sync", False):
+                if v.name in param_names_in_master_weights:
+                    filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
+                        v.name
+                    ]
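+                # optimizer state tensors (e.g. moments) are keyed by the parameter name prefix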
+                for op_k, op_v in optimzier_state_dict.items():
+                    if op_k.startswith(v.name):
+                        filter_optimzier_state_dict[op_k] = op_v
+        return filter_optimzier_state_dict
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
@@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None):
         optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

         if self.args.use_hybrid_parallel:
-            if self.dp_group.rank <= 0:
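+            # under expert parallel, every dp rank owns distinct expert optimizer states and must write them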
+            if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
                 os.makedirs(output_dir, exist_ok=True)
                 logger.info("Saving optimizer files.")
                 if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@ def _save_checkpoint(self, model, metrics=None):
                         safe_serialization=True,
                     )
                 else:
-                    self._save_ckpt_func(
-                        self.optimizer.state_dict(),
-                        os.path.join(output_dir, optimizer_name),
-                    )
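+                    # dp rank 0 saves the full optimizer state; other dp ranks save only the local
+                    # no_sync (MoE expert) states selected by _filter_moe_no_sync_optimizer_params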
+                    if self.dp_group.rank > 0:  # this should only work for MoE saving
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(),
+                            os.path.join(output_dir, optimizer_name),
+                        )
+                    else:
+                        self._save_ckpt_func(
+                            self.optimizer.state_dict(),
+                            os.path.join(output_dir, optimizer_name),
+                        )

         if self.args.should_save or self.args.use_expert_parallel:
             if not self.args.use_hybrid_parallel:
                 logger.info("Saving optimizer files.")
                 if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@ def _save_checkpoint(self, model, metrics=None):
                         safe_serialization=True,
                     )
                 else:
-                    self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
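+                    # same MoE-aware split as the hybrid-parallel branch above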
+                    if self.dp_group.rank > 0:
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
+                        )
+                    else:
+                        self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

             # FIXME: maybe only save one copy
             paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

         if not use_unified_checkpoint:
-            if self.args.data_parallel_rank == 0:
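+            # expert-parallel ranks load their own optimizer shard even when dp rank != 0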
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
                 optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 path = os.path.join(checkpoint, optimizer_name)
                 if os.path.isfile(path):
2476
2510
# broadcast optimizer state in dp group
2477
2511
if self .args .local_rank != - 1 :
2478
2512
dist .barrier ()
2479
- opt_state_dict = broadcast_dp_optimizer (opt_state_dict )
2513
+ if self .args .use_expert_parallel :
2514
+ opt_state_dict = broadcast_moe_optimizer (opt_state_dict )
2515
+ else :
2516
+ if not self .args .should_load_sharding_stage1_model :
2517
+ opt_state_dict = broadcast_dp_optimizer (opt_state_dict )
2480
2518
2481
2519
if opt_state_dict is not None :
2482
2520
# Load in optimizer and scheduler states