Commit 921fc44

[Unified Checkpoint] Support sharding_comm_overlap (#9392)
* update
1 parent 018b530 commit 921fc44

7 files changed: +52 −28 lines changed
paddlenlp/trainer/trainer.py

Lines changed: 17 additions & 1 deletion

@@ -141,6 +141,7 @@
     set_seed,
     should_skip_data,
     speed_metrics,
+    split_parallel_config,
 )
 from .training_args import TrainingArguments
 from .unified_checkpoint import UnifiedCheckpointHandler
@@ -2053,6 +2054,14 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
+            if (
+                hasattr(self.args, "enable_sharding_comm_overlap")
+                and self.args.enable_sharding_comm_overlap
+                and self.args.unified_checkpoint
+                and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
+            ):
+                model.register_sharding_comm_overlap_hook(self.optimizer)
+
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2840,8 +2849,15 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 else:
                     opt_state_dict = None
             else:
+                model = self.model
+                if (
+                    hasattr(self.args, "enable_sharding_comm_overlap")
+                    and self.args.enable_sharding_comm_overlap
+                    and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
+                ):
+                    model = self.model_wrapped
                 opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
-                    model=self.model,
+                    model=model,
                     optimizer=self.optimizer,
                     resume_from_checkpoint=checkpoint,
                 )

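Both new branches in trainer.py check whether sharding comm overlap is active together with `split_param`; the first additionally requires `unified_checkpoint`. A minimal standalone sketch of the shared part of that check, using a hypothetical helper name and reusing the `split_parallel_config` helper introduced in trainer_utils.py below:

from paddlenlp.trainer.trainer_utils import split_parallel_config


def _uses_sharding_overlap_with_split_param(args) -> bool:
    # Hypothetical helper mirroring the checks added in this commit: sharding
    # comm overlap must be requested and the sharding group must split params.
    return (
        getattr(args, "enable_sharding_comm_overlap", False)
        and "split_param" in split_parallel_config(args.sharding_parallel_config)
    )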
paddlenlp/trainer/trainer_utils.py

Lines changed: 8 additions & 0 deletions

@@ -1126,3 +1126,11 @@ def should_skip_data(global_step, skip_data_intervals):
             skip_flag = True
             break
     return skip_flag
+
+
+def split_parallel_config(parallel_config):
+    if "," in parallel_config:
+        parallel_config = set(parallel_config.split(","))
+    else:
+        parallel_config = set(parallel_config.split(" "))
+    return parallel_config

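The relocated `split_parallel_config` helper accepts either a comma-separated or a space-separated config string and returns the option names as a set. A short illustration of the expected behavior (the option names here are only examples):

from paddlenlp.trainer.trainer_utils import split_parallel_config

# Both separators yield the same set of option names.
assert split_parallel_config("split_param,enable_stage1_overlap") == {"split_param", "enable_stage1_overlap"}
assert split_parallel_config("split_param enable_stage1_overlap") == {"split_param", "enable_stage1_overlap"}
assert split_parallel_config("split_param") == {"split_param"}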
paddlenlp/trainer/training_args.py

Lines changed: 5 additions & 20 deletions

@@ -37,6 +37,7 @@
     OptimizerNames,
     SchedulerType,
     ShardingOption,
+    split_parallel_config,
 )
 
 try:
@@ -1096,13 +1097,6 @@ def __post_init__(self):
                 logger.warning("set amp_master_grad to false since amp is disabled.")
                 self.amp_master_grad = False
 
-        def split_parallel_config(parallel_config):
-            if "," in parallel_config:
-                parallel_config = set(parallel_config.split(","))
-            else:
-                parallel_config = set(parallel_config.split(" "))
-            return parallel_config
-
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
 
@@ -1155,29 +1149,20 @@ def split_parallel_config(parallel_config):
                     or "enable_dp_comm_overlap" in pipeline_parallel_config
                 )
                 enable_dp_comm_overlap = using_comm_overlap and self.data_parallel_degree > 1
-                enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1
+                self.enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1
                 assert not (
-                    enable_dp_comm_overlap and enable_sharding_comm_overlap
+                    enable_dp_comm_overlap and self.enable_sharding_comm_overlap
                 ), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"
 
-                if enable_sharding_comm_overlap and not self.amp_master_grad:
+                if self.enable_sharding_comm_overlap and not self.amp_master_grad:
                     raise ValueError(
                         "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
                     )
-                if (
-                    enable_sharding_comm_overlap
-                    and self.unified_checkpoint
-                    and "split_param" in split_parallel_config(self.sharding_parallel_config)
-                ):
-                    logger.warning(
-                        "Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`."
-                    )
-                    enable_sharding_comm_overlap = False
 
                 dygraph_pp_configs = {
                     "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
                     "dp_comm_overlap": enable_dp_comm_overlap,
-                    "sharding_comm_overlap": enable_sharding_comm_overlap,
+                    "sharding_comm_overlap": self.enable_sharding_comm_overlap,
                     "enable_timer": "enable_timer" in pipeline_parallel_config,
                     "release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
                     "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,

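Previously, `__post_init__` silently disabled `sharding_comm_overlap` whenever unified checkpoint and `split_param` were combined; with that branch removed and the flag promoted to `self.enable_sharding_comm_overlap`, the trainer can now handle the combination itself. A hedged configuration sketch (the concrete values are illustrative only and assume a multi-rank `paddle.distributed.launch` run):

from paddlenlp.trainer import TrainingArguments

# Illustrative only: parallel degrees > 1 require launching with enough ranks,
# e.g. via `python -m paddle.distributed.launch`.
args = TrainingArguments(
    output_dir="./checkpoints",
    bf16=True,
    amp_master_grad=True,  # required when enable_sharding_comm_overlap is set
    unified_checkpoint=True,
    pipeline_parallel_degree=2,
    sharding_parallel_degree=4,
    sharding="stage1",
    pipeline_parallel_config="enable_sharding_comm_overlap",
    sharding_parallel_config="split_param",
)
# After __post_init__, the flag is kept on the args object instead of being
# downgraded, so the trainer can register the comm-overlap hook later.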
paddlenlp/trainer/unified_checkpoint/check_completion.py

Lines changed: 2 additions & 1 deletion

@@ -150,7 +150,8 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe
     sharding_group = hcg.get_sharding_parallel_group()
     sharding_rank = sharding_group.rank
     dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-    struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
+    model_state_dict = get_expected_state_dict(model)
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
     if is_sharding_split_param_mode(args):
         # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume.

paddlenlp/trainer/unified_checkpoint/load_local.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def _remove_unused_keys(
 def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
     # Special process with split param.
     if is_sharding_split_param_mode(args):
-        returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
+        returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
         return returned_optim_state_dict
 
     # init and get optimizer LR_Scheduler

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 11 additions & 3 deletions

@@ -15,14 +15,15 @@
 
 import gc
 import os
+from itertools import chain
 
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
 from tqdm.auto import tqdm
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict
+from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
@@ -97,6 +98,7 @@ def gather_splited_param_for_optimizer(optimizer):
     global_rank = dist.get_rank()
     param_slice_info = {}
     param_shape_info = {}
+
     for buffer in optimizer._inner_opt._comm_buffer_list:
         for key in buffer._sharding_param_grad_view.keys():
             param_slice_info[key] = (
@@ -153,7 +155,7 @@ def gather_splited_param_for_optimizer(optimizer):
     return optim_state_dict, master_weights
 
 
-def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
     index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -177,7 +179,13 @@ def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
     expected_keys = []
     param_slice_info = {}
    param_shape_info = {}
-    for buffer in optimizer._inner_opt._comm_buffer_list:
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+        model = unwrap_model(model)
+
+    for buffer in comm_buffer_list:
         for key in buffer._sharding_param_grad_view.keys():
             begin = buffer._sharding_param_grad_view[key]._param_begin
             end = buffer._sharding_param_grad_view[key]._param_end

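When sharding comm overlap is enabled, the comm buffers live on the (still wrapped) pipeline model rather than on `optimizer._inner_opt`, keyed by chunk in `_chunk_2_comm_buffers`; `chain` merely flattens that mapping into a single list before the model is unwrapped. A toy illustration of the flattening step (the dict values are placeholders, not real comm buffers):

from itertools import chain

# Placeholder stand-ins for per-chunk comm buffer lists.
chunk_2_comm_buffers = {0: ["buffer_a", "buffer_b"], 1: ["buffer_c"]}

comm_buffer_list = list(chain(*chunk_2_comm_buffers.values()))
print(comm_buffer_list)  # ['buffer_a', 'buffer_b', 'buffer_c']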
paddlenlp/trainer/unified_checkpoint/utils.py

Lines changed: 8 additions & 2 deletions

@@ -24,7 +24,11 @@
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
 from paddlenlp.trainer.utils.helper import distributed_isfile
-from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype
+from paddlenlp.transformers.model_utils import (
+    PretrainedModel,
+    get_parameter_dtype,
+    unwrap_model,
+)
 from paddlenlp.transformers.utils import dtype_byte_size
 from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
@@ -193,6 +197,8 @@ def get_expected_state_dict(model_to_save):
     """
     Get trainable state_dict of model_to_save.
     """
+    model_to_save = unwrap_model(model_to_save)
+
     if isinstance(model_to_save, PretrainedModel):
         state_dict = model_to_save.state_dict()
         if (
@@ -221,7 +227,7 @@ def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weight
     params2rank = optimizer._param2rank
 
     model_state_dict = get_expected_state_dict(model)
-    struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
     expected_keys = []
     for key in list(sharded_metadata["all_optimizer_keys"]):

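`get_expected_state_dict` now unwraps the model first, so it produces the same structured-name mapping whether the caller passes the bare `PretrainedModel` or the fleet-wrapped one (as trainer.py now does when sharding comm overlap is active). A minimal sketch of the kind of unwrapping involved; this is an assumption about the helper's behavior, not its actual implementation:

def unwrap_model_sketch(model):
    # Assumed behavior: peel off distributed wrappers, which in Paddle expose
    # the inner network through the `_layers` attribute, until the underlying
    # PretrainedModel is reached.
    while hasattr(model, "_layers"):
        model = model._layers
    return model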