@@ -42,12 +42,18 @@


def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
):
    """Merge the splited param in sharding group."""
    global_rank = dist.get_rank()
    for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
            continue

        static_name = key if is_master_weights else generate_base_static_name(key)[0]
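A note on the numel change, which this diff repeats in every hunk: depending on the Paddle version, Tensor.numel() returns either a plain Python int or a 0-D Tensor, and .item() only exists on the Tensor form, so int(...) is the form that works on both. A minimal sketch of the pattern, assuming a recent Paddle build:

    import paddle

    # A scalar optimizer state, e.g. a beta1 accumulator.
    scalar_state = paddle.zeros([1])

    # int() normalizes whatever numel() returns (int or 0-D Tensor);
    # .item() would fail on builds where numel() is already an int.
    if int(scalar_state.numel()) == 1:
        print("scalar state, skip merging")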
@@ -89,10 +95,21 @@ def merge_splited_param(
                    )
                    dist.stream.send(tensor, dst=recv_rank)
            state_dict.pop(key)
+
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
    return state_dict


-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
    hcg = fleet.get_hybrid_communicate_group()
    sharding_group = hcg.get_sharding_parallel_group()
    global_rank = dist.get_rank()
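The intent of the pruning pass added in the hunk above: when checkpoint quantization is active (ckpt_quant_stage other than "O0"), the scalar (numel == 1) states of partially split params should survive only on the rank designated to receive the merged tensor, and every other rank drops its copy. A standalone sketch of that rule, with hypothetical names and a plain dict standing in for the real state dict:

    # static name -> rank that receives (and keeps) the merged tensor
    recv_table = {"linear_0.w_0": 0}
    partial_tensor_list = ["linear_0.w_0"]
    global_rank = 1  # pretend we are rank 1

    state_dict = {"linear_0.w_0/moment1_0": 0.9}
    for key in list(state_dict.keys()):
        static_name = key.split("/")[0]  # stand-in for generate_base_static_name
        if static_name in partial_tensor_list and global_rank != recv_table[static_name]:
            state_dict.pop(key)  # not the recv rank, so drop the scalar state
    print(state_dict)  # -> {}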
@@ -127,7 +144,7 @@ def gather_splited_param_for_optimizer(optimizer):
    for key in list(optim_state_dict.keys()):
        static_name, _ = generate_base_static_name(key)
        if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                continue
            begin, end = param_slice_info[static_name]
            shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +166,17 @@ def gather_splited_param_for_optimizer(optimizer):
        recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
        send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
    if master_weights is not None:
-        merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
+        merge_splited_param(
+            master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True, ckpt_quant_stage
+        )
    return optim_state_dict, master_weights


-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
    returned_optim_state_dict = nested_copy(optimizer.state_dict())

    index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
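Because every new parameter defaults to ckpt_quant_stage="O0" (quantization off), existing call sites keep their old behavior unchanged. A hedged usage sketch of the updated entry point; the optimizer object and the "O1" stage value are assumptions for illustration, not taken from this diff:

    # Gather the full optimizer state before writing a quantized checkpoint;
    # omitting ckpt_quant_stage (or passing "O0") preserves the old path.
    optim_state_dict, master_weights = gather_splited_param_for_optimizer(
        optimizer, ckpt_quant_stage="O1"
    )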
@@ -217,7 +238,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
    if len(resolved_archive_file_mw) > 1:
        resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
        returned_state_dict = {}

        if model.config.tensor_parallel_degree > 1:
@@ -232,24 +255,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
            if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                continue
            if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
            else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
            returned_state_dict.update(state_dict)
            del state_dict
            gc.collect()

        return returned_state_dict

    # get tp params
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )

    # need to split param for different sharding rank, maybe need to deal with oom issue.
    for key in list(state_dict_optim.keys()):
        key_name = key.split("/")
        static_name = struct2static_name_mappings.get(key_name[0], None)

-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
            shape, numel, index, padded_size = param_shape_info[static_name]
            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
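The reshape([-1]) at the end of the hunk above prepares each merged tensor for per-rank slicing: every sharding rank flattens the full tensor and keeps only its [begin, end) window from param_slice_info. A runnable sketch of that pattern with made-up sizes:

    import paddle

    full = paddle.arange(12, dtype="float32").reshape([3, 4])  # merged optimizer state
    begin, end = 4, 8                # this rank's window, as in param_slice_info
    local_shard = full.reshape([-1])[begin:end]
    print(local_shard.numpy())       # [4. 5. 6. 7.]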
@@ -284,7 +321,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

    for key in list(state_dict_master_weight.keys()):
        static_name = struct2static_name_mappings.get(key, None)
-        if state_dict_master_weight[key].numel().item() > 1:
+        if int(state_dict_master_weight[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
            shape, numel, index, padded_size = param_shape_info[static_name]
            state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])