     get_expected_state_dict,
     get_optimizer_shard_files,
     mapping_optimizer_tp_actions,
+    update_master_weight_status,
 )
 
 __all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]
 
 
 def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
 ):
     """Merge the splited param in sharding group."""
     global_rank = dist.get_rank()
     for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
             continue
 
         static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
                 )
                 dist.stream.send(tensor, dst=recv_rank)
             state_dict.pop(key)
+
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
     return state_dict
 
 
-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     global_rank = dist.get_rank()
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
     for key in list(optim_state_dict.keys()):
         static_name, _ = generate_base_static_name(key)
         if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                 continue
             begin, end = param_slice_info[static_name]
             shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer):
         recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
         send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]
 
-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
     if master_weights is not None:
         merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
     return optim_state_dict, master_weights
 
 
-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
     index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
 
+    has_master_weights, index_filename_master_weights = update_master_weight_status(
+        args, optimizer, has_master_weights, safe_serialization=True
+    )
+
     if has_master_weights:
         returned_optim_state_dict["master_weights"] = {}
         resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
         if len(resolved_archive_file_mw) > 1:
             resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")
 
-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
         returned_state_dict = {}
 
         if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()
 
         return returned_state_dict
 
     # get tp params
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )
 
     # need to split param for different sharding rank, maybe need to deal with oom issue.
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
 
-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
 
     for key in list(state_dict_master_weight.keys()):
         static_name = struct2static_name_mappings.get(key, None)
-        if state_dict_master_weight[key].numel().item() > 1:
+        if int(state_dict_master_weight[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                 paddle.framework._current_expected_place(), False
             )
         returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
+
+        # master weight cast (only in remove_master_weight)
+        if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
+            returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
+                returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
+            )
+
         returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict
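
For context, a minimal caller-side sketch of how the new ckpt_quant_stage argument might be threaded through the two public entry points touched above. The trainer objects (args, model, optimizer, resume_from_checkpoint) and the "O1" stage value are placeholders for illustration only; they are not part of this diff.

# Save path: gather the sharded optimizer state with the chosen checkpoint quant stage.
optim_state_dict, master_weights = gather_splited_param_for_optimizer(
    optimizer, ckpt_quant_stage="O1"
)

# Load path: resume with the same quant stage the checkpoint was written with.
returned_optim_state_dict = load_unified_optimizer_split_param(
    args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O1"
)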