@@ -224,12 +224,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         for shard_file in resolved_archive_file:
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
-
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
-
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()
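
The hunk above switches load_state_dict from device="expected" to device="cpu", so each shard is deserialized into host memory rather than directly onto the accelerator. A minimal sketch of that pattern, assuming a hypothetical load_shard stand-in for the real loader:

# Load-to-CPU-first sketch; load_shard is a hypothetical stand-in for
# load_state_dict, not the actual PaddleNLP API.
import paddle

def load_shard_to_host(load_shard, shard_path, keys):
    # Deserialize on the CPU so the accelerator never holds both the live
    # model and a freshly loaded shard at the same time.
    state_dict = load_shard(shard_path, keys, device="cpu")
    place = paddle.framework._current_expected_place()
    # Move tensors one by one; each host copy becomes collectable as soon
    # as its device copy exists.
    return {k: v._copy_to(place, False) for k, v in state_dict.items()}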
@@ -238,13 +236,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
 
     # get tp params
     state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
-    if has_master_weights:
-        state_dict_master_weight = load_resolved_archive_file(
-            resolved_archive_file_mw,
-            sharded_metadata_mw,
-            expected_keys,
-            is_master_weights=True,
-        )
 
     # need to split param for different sharding rank, maybe need to deal with oom issue.
     for key in list(state_dict_optim.keys()):
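
The deleted block above loaded master weights eagerly, so the optimizer states and the master weights were resident at the same time; the matching addition in the next hunk re-issues the call only after the optimizer loop has drained state_dict_optim via pop(). A rough, self-contained sketch of that sequencing (load_archive is a made-up placeholder, not PaddleNLP code):

# Sequencing sketch only; load_archive is a made-up stand-in for the real
# shard loader.
def load_archive(name):
    return {f"{name}.param.{i}": i for i in range(3)}

def drain(src, dst):
    # pop() releases each source entry as soon as it is handed over, so the
    # drained dict shrinks while the destination grows.
    for key in list(src.keys()):
        dst[key] = src.pop(key)

returned = {}
drain(load_archive("optimizer"), returned)       # first archive fully consumed...
drain(load_archive("master_weights"), returned)  # ...before the second is read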
@@ -266,15 +257,24 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
                 )
             )
-
         if has_master_weights:
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
+
+        state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False)
+
         returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name
 
     if has_master_weights:
+        state_dict_master_weight = load_resolved_archive_file(
+            resolved_archive_file_mw,
+            sharded_metadata_mw,
+            expected_keys,
+            is_master_weights=True,
+        )
+
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings.get(key, None)
             if state_dict_master_weight[key].numel().item() > 1:
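
For reference, the key_name assembly in the hunk above joins the static parameter name, the optional FP32_MASTER tag, and the optimizer-state suffix with underscores. A tiny illustration (the FP32_MASTER value here is an assumption for the example, not read from PaddleNLP):

# Illustrative only; FP32_MASTER's value is assumed for this example.
FP32_MASTER = "fp32_master_0"
static_name, suffix = "linear_0.w_0", "moment1_0"
print("_".join([static_name, FP32_MASTER, suffix]))
# -> linear_0.w_0_fp32_master_0_moment1_0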
@@ -292,6 +292,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                         paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
                     )
                 )
+            state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
+                paddle.framework._current_expected_place(), False
+            )
             returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
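
Both _copy_to additions move a tensor that was loaded and padded on the CPU onto the current expected place with a non-blocking copy. A standalone demonstration (assumes a working Paddle install; _current_expected_place() resolves to the GPU when one is visible, otherwise the CPU):

# Demonstrates the _copy_to call used in the added lines above.
import paddle

host_t = paddle.to_tensor([0.0, 1.0, 2.0, 3.0], place=paddle.CPUPlace())
place = paddle.framework._current_expected_place()
dev_t = host_t._copy_to(place, False)  # second argument is `blocking`
print(dev_t.place)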