30
30
from paddlenlp .transformers .model_utils import (
31
31
PretrainedModel ,
32
32
_load_state_dict_into_model ,
33
+ faster_set_state_dict ,
33
34
get_parameter_dtype ,
34
35
load_state_dict ,
35
36
unwrap_model ,
64
65
from paddlenlp .utils .log import logger
65
66
66
67
if is_safetensors_available ():
67
- from safetensors import safe_open
68
+ # from safetensors import safe_open
68
69
from safetensors .numpy import save_file as safe_save_file
69
70
71
+ from paddlenlp .utils .safetensors import fast_safe_open as safe_open
70
72
71
73
FP32_MASTER = "fp32_master_0"
72
74
optimizer_scalar_name = [
@@ -195,7 +197,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
195
197
Returns:
196
198
None
197
199
"""
198
-
199
200
if paddle .distributed .get_world_size () <= 1 :
200
201
load_single_card_checkpoint (args , model , resume_from_checkpoint )
201
202
return
@@ -221,7 +222,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
221
222
pretrained_model_name_or_path = resume_from_checkpoint ,
222
223
index_filename = os .path .join (resume_from_checkpoint , index_filename ),
223
224
)
224
-
225
225
loaded_keys = sharded_metadata ["all_checkpoint_keys" ]
226
226
227
227
model_state_dict = get_expected_state_dict (model )
@@ -265,7 +265,9 @@ def _remove_unused_keys(
265
265
else :
266
266
tp_actions = model .get_tensor_parallel_convert_actions (model .config , loaded_keys , ignore_error = True )
267
267
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
268
- state_dict = load_state_dict (shard_file , tp_actions if pre_tensor_parallel_split else None , expected_keys )
268
+ state_dict = load_state_dict (
269
+ shard_file , tp_actions if pre_tensor_parallel_split else None , expected_keys , device = "expected"
270
+ )
269
271
270
272
if not pre_tensor_parallel_split :
271
273
# Since we load all keys but we only need one of pipeline stages
@@ -278,11 +280,12 @@ def _remove_unused_keys(
278
280
None , model .config , state_dict = state_dict , ignore_error = len (resolved_archive_file ) > 1
279
281
)
280
282
281
- error_msgs += _load_state_dict_into_model (model , state_dict , "" )
283
+ # error_msgs += _load_state_dict_into_model(model, state_dict, "")
284
+ error_msgs += faster_set_state_dict (model , state_dict , strict_dtype = False )
282
285
283
286
# force memory release
284
287
del state_dict
285
- gc .collect ()
288
+ # gc.collect()
286
289
287
290
if len (error_msgs ) > 0 :
288
291
error_msg = "\n \t " .join (error_msgs )
@@ -336,6 +339,7 @@ def unified_checkpoint_into_shards(
336
339
tp_actions = model_to_save .get_tensor_parallel_convert_actions (
337
340
model_to_save .config , state_dict .keys (), is_split = False , ignore_error = True
338
341
)
342
+ logger .info ("Unified model tensor parallel weights in shards" )
339
343
state_dict = merge_tensor_parallel_with_shard (state_dict , tp_actions , all_filter_keys )
340
344
341
345
# build index json file
@@ -489,6 +493,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
489
493
# This should always be a list but, just to be sure.
490
494
if not isinstance (resolved_archive_file , list ):
491
495
resolved_archive_file = [resolved_archive_file ]
496
+
492
497
if len (resolved_archive_file ) > 1 :
493
498
resolved_archive_file = tqdm (resolved_archive_file , desc = "Loading optimizer shards" )
494
499
@@ -536,10 +541,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
536
541
tp_actions = mapping_optimizer_tp_actions (tp_actions , expected_keys )
537
542
538
543
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
539
- state_dict = load_state_dict (shard_file , tp_actions , expected_keys )
544
+ state_dict = load_state_dict (shard_file , tp_actions , expected_keys , device = "expected" )
540
545
else :
541
546
# for pipeline model, we don't need to use tp_actions
542
- state_dict = load_state_dict (shard_file , None , expected_keys )
547
+ state_dict = load_state_dict (shard_file , None , expected_keys , device = "expected" )
543
548
544
549
returned_state_dict .update (state_dict )
545
550
# force memory release
@@ -552,7 +557,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
552
557
state_dict_master_weight = load_resolved_archive_file (
553
558
resolved_archive_file_mw , sharded_metadata_mw , expected_keys_mw , is_master_weights = True
554
559
)
555
-
556
560
# rename optimizer param
557
561
for key in list (state_dict_optim .keys ()):
558
562
key_name = key .split ("/" )
@@ -561,13 +565,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
561
565
key_name = "_" .join ([static_name , FP32_MASTER , key_name [1 ]])
562
566
else :
563
567
key_name = "_" .join ([static_name , key_name [1 ]])
564
- returned_optim_state_dict [key_name ] = state_dict_optim [ key ]
568
+ returned_optim_state_dict [key_name ] = state_dict_optim . pop ( key )
565
569
returned_optim_state_dict [key_name ].name = key_name
566
570
567
571
if has_master_weights :
568
572
for key in list (state_dict_master_weight .keys ()):
569
573
static_name = struct2static_name_mappings [key ]
570
- returned_optim_state_dict ["master_weights" ][static_name ] = state_dict_master_weight [ key ]
574
+ returned_optim_state_dict ["master_weights" ][static_name ] = state_dict_master_weight . pop ( key )
571
575
returned_optim_state_dict ["master_weights" ][static_name ].name = "_" .join ([static_name , FP32_MASTER ])
572
576
573
577
returned_optim_state_dict = nested_copy_place (
@@ -639,6 +643,7 @@ def unified_optimizer_into_shards(
639
643
tp_actions = model .get_tensor_parallel_convert_actions (
640
644
model .config , model_keys , is_split = False , ignore_error = True
641
645
)
646
+ logger .info ("Unified optimizer tensor parallel in shards" )
642
647
optim_state_dict = merge_tensor_parallel_for_optimizer (
643
648
optim_state_dict ,
644
649
tp_actions ,
@@ -647,6 +652,7 @@ def unified_optimizer_into_shards(
647
652
paddle .device .cuda .empty_cache ()
648
653
649
654
if master_weights is not None :
655
+ logger .info ("Unified master weight tensor parallel in shards" )
650
656
master_weights = merge_tensor_parallel_for_optimizer (
651
657
master_weights ,
652
658
tp_actions ,
@@ -702,7 +708,6 @@ def unified_optimizer_into_shards(
702
708
def check_unified_checkpoint (args , model , resume_from_checkpoint , safe_serialization = False ):
703
709
index_filename = select_model_weight_index (args , model , resume_from_checkpoint , safe_serialization , local = False )
704
710
index_filename = os .path .join (resume_from_checkpoint , index_filename )
705
-
706
711
# Find index json file and distribute this file in global group.
707
712
if distributed_isfile (index_filename ):
708
713
distributed_file (index_filename )
@@ -1604,7 +1609,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
1604
1609
tp_group = hcg .get_model_parallel_group ()
1605
1610
pp_group = hcg .get_pipe_parallel_group ()
1606
1611
1607
- logger .info ("Unified checkpoint generating sharded_index json files." )
1612
+ logger .info (
1613
+ f"Unified checkpoint: generating sharded_index json files for { 'optimizer or master weight' if is_optimizer else 'model weight' } ."
1614
+ )
1608
1615
1609
1616
if tp_group .nranks > 1 :
1610
1617
dist .all_gather_object (index_file_list , index_file , tp_group )
@@ -1713,8 +1720,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):
1713
1720
1714
1721
1715
1722
def merge_tensor_parallel_with_shard (state_dict , tp_actions , all_filter_keys ):
1716
- logger .info ("Unified checkpoint merge tensor parallel in shards" )
1717
-
1718
1723
hcg = fleet .get_hybrid_communicate_group ()
1719
1724
tp_group = hcg .get_model_parallel_group ()
1720
1725
tp_rank = tp_group .rank
@@ -1740,7 +1745,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
1740
1745
action = tp_actions .pop (key )
1741
1746
tensor = action (ret ) if is_dst else None
1742
1747
else :
1743
- tensor = tensor ._copy_to (paddle .CPUPlace (), False ) if is_dst else None
1748
+ tensor = tensor ._copy_to (paddle .CUDAPinnedPlace (), False ) if is_dst else None
1744
1749
1745
1750
if is_dst :
1746
1751
state_dict_to_save [key ] = tensor
@@ -1753,8 +1758,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
1753
1758
1754
1759
1755
1760
def merge_tensor_parallel_for_optimizer (state_dict , tp_actions , all_filter_keys ):
1756
- logger .info ("Unified optimizer tensor parallel in shards" )
1757
-
1761
+ # Core function for UC
1758
1762
hcg = fleet .get_hybrid_communicate_group ()
1759
1763
tp_group = hcg .get_model_parallel_group ()
1760
1764
tp_rank = tp_group .rank
@@ -1773,14 +1777,14 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
1773
1777
# for example: beta1, beta2
1774
1778
if tensor .numel ().item () == 1 :
1775
1779
tensor = (
1776
- tensor ._copy_to (paddle .CPUPlace (), False ) if is_dst else None
1780
+ tensor ._copy_to (paddle .CUDAPinnedPlace (), False ) if is_dst else None
1777
1781
) # Need broadcast when loaded
1778
1782
else :
1779
1783
ret = distributed_gather (tensor , dst = j , group = tp_group , offload = False )
1780
1784
action = tp_actions [model_key ]
1781
1785
tensor = action (ret ) if is_dst else None
1782
1786
else :
1783
- tensor = tensor ._copy_to (paddle .CPUPlace (), False ) if is_dst else None
1787
+ tensor = tensor ._copy_to (paddle .CUDAPinnedPlace (), False ) if is_dst else None
1784
1788
1785
1789
if is_dst :
1786
1790
state_dict_to_save [filter_keys [i ]] = tensor
@@ -1892,7 +1896,10 @@ def nested_copy_place(inputs, place=None, blocking=False):
1892
1896
outputs [key ] = nested_copy_place (inputs [key ], place , blocking )
1893
1897
return outputs
1894
1898
if isinstance (inputs , paddle .Tensor ):
1895
- inputs = inputs if inputs .place == place else inputs ._copy_to (place , blocking )
1899
+ if inputs .place ._equals (place ):
1900
+ return inputs
1901
+ else :
1902
+ return inputs ._copy_to (place , blocking )
1896
1903
return inputs
1897
1904
1898
1905
0 commit comments