@@ -2236,16 +2236,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             self.model_wrapped.get_all_parameters(convert2cpu=True)

         if self.args.should_save_model_state:
-            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
-            # backup and remove unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = []
-
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
-
-            # recover unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(output_dir, exist_ok=True)
@@ -2523,10 +2514,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`

-        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
         if (
             strtobool(os.getenv("FLAG_LLM_PDC", "False"))
-            and local_rank == 0
+            and paddle.distributed.get_rank() == 0
             and self.args.unified_checkpoint
             and "async_save" in self.args.unified_checkpoint_config
         ):
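The swapped condition changes which processes pass this guard: PADDLE_RANK_IN_NODE is the rank within a single node, so rank 0 of every node satisfied the old check, whereas paddle.distributed.get_rank() is the global rank, so only one process in the whole job now writes the async-save metadata. A minimal sketch of the two predicates (helper names are illustrative, not from the patch):

import os
import paddle.distributed as dist

def is_node_rank_zero() -> bool:
    # old guard: rank within the current node; true for one process on every node
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0)) == 0

def is_global_rank_zero() -> bool:
    # new guard: rank across the whole distributed job; true for exactly one process
    return dist.get_rank() == 0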
@@ -2537,9 +2527,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
-                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
-                    json.dump(save_info, f)
+            if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):  # afs cannot overwrite
+                os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
+            with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                json.dump(save_info, f)

         if self.args.should_save:
             if self.tokenizer is not None:
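This hunk also changes behaviour, not just formatting: the old code wrote async_save_info.json only when the file was absent, so a stale file from a previous run was kept; the new code removes any existing file first (the "# afs cannot overwrite" comment points at a filesystem that cannot overwrite files in place) and then always writes the current save_info. A self-contained sketch of the remove-then-write pattern, with an illustrative path and payload:

import json
import os

def write_json_fresh(path: str, payload: dict) -> None:
    # some filesystems cannot overwrite an existing file in place,
    # so delete the old copy before writing the new one
    if os.path.exists(path):
        os.remove(path)
    with open(path, "w") as f:
        json.dump(payload, f)

write_json_fresh("/tmp/async_save_info.json", {"skip_save_model_weight": False})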
@@ -2548,7 +2539,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

         if self.args.unified_checkpoint:
+            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
+            # backup and remove unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = []
+
             self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
+
+            # recover unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
+
             return

         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
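The backup/clear/restore block removed from save_model in the first hunk lands here, scoped to the unified-checkpoint branch of _save: when the save is triggered outside the training loop (not self.is_in_train), unified_checkpoint_config is temporarily emptied so options such as "async_save" do not apply, and it is restored right after save_unified_checkpoint returns. A minimal sketch of the same guard written as a try/finally for illustration; the function name and structure are not from the patch, which restores the value inline:

def save_unified_checkpoint_guarded(trainer, output_dir):
    backup = trainer.args.unified_checkpoint_config
    if not trainer.is_in_train:
        # outside the training loop, drop the extra unified-checkpoint options
        trainer.args.unified_checkpoint_config = []
    try:
        trainer.unified_checkpoint_handler.save_unified_checkpoint(
            trainer.model, trainer.optimizer, output_dir
        )
    finally:
        if not trainer.is_in_train:
            # restore the configured options for subsequent saves
            trainer.args.unified_checkpoint_config = backup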