@@ -569,7 +569,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 base_weight_name=weight_name,
                 model_wrapped=self.model_wrapped,
             )
-            self.model.set_state_dict(state_dict)
+            old_state_dict = self.model.state_dict()
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                if k not in old_state_dict or id(v) != id(old_state_dict[k]):
+                    new_state_dict[k] = v
+            self.model.set_state_dict(new_state_dict)
         else:
             if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
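
Note on the hunk above: the loaded state dict is filtered so that entries which are the very same object as the tensor the model already holds (i.e. already loaded in place) are not passed to set_state_dict() again. A minimal sketch of that identity-based filter, using plain dicts as stand-ins for paddle Tensors; the helper name and demo values are illustrative, not from the PR:

    def filter_reused_entries(loaded, current):
        """Keep entries that are new or are not the same object as the
        tensor the model already holds, so unchanged in-place loads are
        not applied a second time."""
        return {
            k: v
            for k, v in loaded.items()
            if k not in current or id(v) != id(current[k])
        }

    shared = object()                       # stands in for a tensor loaded in place
    current = {"w": shared, "b": object()}  # model.state_dict()
    loaded = {"w": shared, "b": object(), "extra": object()}

    print(sorted(filter_reused_entries(loaded, current)))  # ['b', 'extra']
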
@@ -891,7 +896,8 @@ def _inner_training_loop(
 
             npu_accelerate_plugin(self.optimizer)
 
-        self.timers and self.timers("read-data").start()
+        if self.args.ignore_data_skip:
+            self.timers and self.timers("read-data").start()
 
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
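
Note: this hunk, together with the two hunks below, wraps every read-data timer call in `if self.args.ignore_data_skip:`, so the timer only runs when data skipping on resume is disabled. A hedged sketch of the `self.timers and self.timers(name).start()` short-circuit idiom being gated; the Timer classes here are stand-ins, not PaddleNLP's implementation:

    import time

    class _Timer:
        def __init__(self, name):
            self.name = name
            self._t0 = None
            self.elapsed = 0.0
        def start(self):
            self._t0 = time.time()
        def stop(self):
            self.elapsed += time.time() - self._t0

    class _Timers:
        def __init__(self):
            self._timers = {}
        def __call__(self, name):
            return self._timers.setdefault(name, _Timer(name))

    timers = _Timers()        # would be None when timing is disabled
    ignore_data_skip = True   # mirrors self.args.ignore_data_skip

    # `timers and timers("read-data").start()` is a no-op when timers is
    # None; the PR additionally skips the read-data timer entirely unless
    # ignore_data_skip is set.
    if ignore_data_skip:
        timers and timers("read-data").start()
        time.sleep(0.01)      # stand-in for fetching a batch
        timers and timers("read-data").stop()
        print(f"read-data: {timers('read-data').elapsed:.3f}s")
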
@@ -907,7 +913,9 @@ def _inner_training_loop(
                     inputs = split_inputs_sequence_dim(inputs)
                 if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim_load_balance(inputs)
-                self.timers and self.timers("read-data").stop()
+                if self.args.ignore_data_skip:
+                    self.timers and self.timers("read-data").stop()
+
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
@@ -1098,7 +1106,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
-            self.timers and self.timers("read-data").start()
+
+            if self.args.ignore_data_skip:
+                self.timers and self.timers("read-data").start()
 
         if step < 0:
             logger.warning(
@@ -2462,10 +2472,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             if state_dict is None:
                 state_dict = self.model.state_dict()
 
-            self._save_ckpt_func(
-                state_dict,
-                os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
-            )
+            if self.args.should_save_sharding_stage1_model:
+                state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config(
+                    unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict
+                )
+                variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix())
+            else:
+                variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)
+
+            self._save_ckpt_func(state_dict, os.path.join(output_dir, variant))
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
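
Note on the hunk above: when a sharding stage-1 model should be saved, the state dict is first rewritten through sharding_io.manipulate_state_dict_and_config() and the weights file gets the sharded name suffix instead of the plain weight_name_suffix. A sketch of the file-name selection only, with stand-in helpers; `_add_variant` here mirrors what its call sites imply, and everything else is assumed for illustration:

    import os

    PADDLE_WEIGHTS_NAME = "model_state.pdparams"

    def _add_variant(weights_name, variant=None):
        # "model_state.pdparams" + "tp00_shard01" -> "model_state.tp00_shard01.pdparams"
        if variant:
            stem, ext = weights_name.rsplit(".", 1)
            return f"{stem}.{variant}.{ext}"
        return weights_name

    def pick_weights_file(output_dir, should_save_sharding_stage1_model,
                          sharded_name_suffix=None, weight_name_suffix=None):
        # In the real _save(), the sharded branch also rewrites state_dict via
        # sharding_io.manipulate_state_dict_and_config(...) before saving.
        if should_save_sharding_stage1_model:
            variant = _add_variant(PADDLE_WEIGHTS_NAME, sharded_name_suffix)
        else:
            variant = _add_variant(PADDLE_WEIGHTS_NAME, weight_name_suffix)
        return os.path.join(output_dir, variant)

    print(pick_weights_file("ckpt-100", True, sharded_name_suffix="tp00_shard01"))
    # ckpt-100/model_state.tp00_shard01.pdparams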