@@ -354,9 +354,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

     # filter actions for pipeline mode
     if hcg.get_pipe_parallel_group().nranks > 1:
@@ -373,10 +371,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             if i > len(filter_keys) - 1:
                 continue
             key = filter_keys[i]
-            tensor = state_dict[key]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and not getattr(tensor, "no_sync", False):
+            if key not in state_dict:
                 continue
+            tensor = state_dict[key]
             if key in tp_actions:
                 # Get tensor size
                 tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
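The hunk above swaps the dp_rank/no_sync filter for a plain key-presence check: keys that a rank does not hold in its local state_dict (the removed comment ties the old filter to expert parallel) are now simply skipped. A minimal standalone sketch of that pattern, with illustrative names not taken from the patch:

# Illustrative sketch only: skip keys this rank does not hold locally.
def collect_local_tensors(state_dict, filter_keys):
    """Return only the tensors present in this rank's shard, in filter_keys order."""
    collected = {}
    for key in filter_keys:
        if key not in state_dict:
            # Keys held by other ranks are simply absent from the local shard.
            continue
        collected[key] = state_dict[key]
    return collected

# Example: this rank's shard holds only one of the two expert weights.
local_shard = {"expert_1.weight": [0.1, 0.2]}
print(collect_local_tensors(local_shard, ["expert_0.weight", "expert_1.weight"]))
# -> {'expert_1.weight': [0.1, 0.2]}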
@@ -405,21 +402,13 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     return state_dict_to_save


-def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
+def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
     """
     Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight.
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-
-    no_sync_kname = []
-    if model_state_dict is not None:
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)

     state_dict_to_save = {}
     max_key_len = max([len(_) for _ in all_filter_keys])
@@ -430,10 +419,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys,
                 continue
             # get base model key
             model_key = filter_keys[i].split("/")[0]
-            tensor = state_dict[filter_keys[i]]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and model_key not in no_sync_kname:
+            if filter_keys[i] not in state_dict:
                 continue
+            tensor = state_dict[filter_keys[i]]
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
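With model_state_dict dropped from the signature, a call to the updated helper takes only the three remaining arguments. A hedged caller-side sketch (the variable names are assumptions, not shown in this diff):

merged_optim_state = merge_tensor_parallel_for_optimizer(optim_state_dict, tp_actions, all_filter_keys)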