
Commit 0e96b0f

LiYuRio and Meiyim authored
Cherry-pick some PRs from incubate/paddlenlp-fleety (#9253)
* support pp-sharding reshard (#9153)
* support best unbalanced pp scheduler (#9235)
* remove pp hack (#9189)

Co-authored-by: Meiyim <chenxuyi@baidu.com>
1 parent 3007c79 commit 0e96b0f

File tree

3 files changed: +34, -12 lines

paddlenlp/trainer/trainer.py

Lines changed: 0 additions & 9 deletions
@@ -2270,13 +2270,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         self._pp_data_buffer = []

         model.train()
-        # hack pipeline-layers
-        # since the pipeline layer will check input is valid every iter.
-        # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement.
-        config_backup = model.micro_batch_size, model.accumulate_steps
-        model.micro_batch_size = self.args.per_device_train_batch_size
-        model.accumulate_steps = self.args.gradient_accumulation_steps
-
         if model._dp_comm_overlap or model._sharding_comm_overlap:
             for _, buffers in model._chunk_2_comm_buffers.items():
                 for buffer in buffers:
@@ -2291,8 +2284,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         with self.autocast_smart_context_manager():
             loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

-        model.micro_batch_size, model.accumulate_steps = config_backup
-
         return loss.detach()

     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
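
Note: the block removed above temporarily overrode the pipeline wrapper's micro_batch_size and accumulate_steps around forward_backward_pipeline so that batch-size warmup could change the effective gradient accumulation on the fly. For context, a minimal stand-alone sketch of that removed pattern, using only the attribute names visible in the diff (the helper name and the run_step callback are illustrative, not part of PaddleNLP):

def run_pipeline_step_with_override(model, args, inputs, run_step):
    # Sketch of the deleted hack: back up the pipeline schedule settings,
    # override them from the training arguments, then restore afterwards.
    config_backup = model.micro_batch_size, model.accumulate_steps
    model.micro_batch_size = args.per_device_train_batch_size
    model.accumulate_steps = args.gradient_accumulation_steps
    try:
        return run_step(model, inputs)  # e.g. model.forward_backward_pipeline(inputs, scaler)
    finally:
        # restore the original pipeline configuration
        model.micro_batch_size, model.accumulate_steps = config_backup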

paddlenlp/trainer/training_args.py

Lines changed: 2 additions & 0 deletions
@@ -1120,6 +1120,7 @@ def split_parallel_config(parallel_config):
                         "enable_clear_every_step_cache",
                         "enable_overlap_p2p_comm",
                         "disable_batch_p2p_comm",
+                        "best_unbalanced_scheduler",
                     ]:
                         raise ValueError(
                             f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -1158,6 +1159,7 @@ def split_parallel_config(parallel_config):
                 "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                 "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
                 "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
+                "best_unbalanced_scheduler": "best_unbalanced_scheduler" in pipeline_parallel_config,
             }
             if dygraph_pp_configs["dp_comm_overlap"]:
                 raise ValueError("overlap has accuracy issue")  # TODO: fix `overalap` + `delay_scale` issue

paddlenlp/trainer/utils/reshard/pp_reshard.py

Lines changed: 32 additions & 3 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from collections import OrderedDict

 from paddle.distributed.fleet.model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
     return _GLOBAL_INDEX_LAYER_FUNC


+_GLOBAL_SNAME_TO_TNAME_FUNC = None
+
+
+def register_sname_to_tname_func(func):
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    _GLOBAL_SNAME_TO_TNAME_FUNC = func
+
+
+def has_register_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None
+
+
+def get_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
+    return _GLOBAL_SNAME_TO_TNAME_FUNC
+
+
 class LayerNameScope:
     """
     layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
         self._segments = OrderedDict()
         self._layer_to_segment = OrderedDict()
         self._param_to_tname = OrderedDict()
+        self._wname_to_rname = OrderedDict()

     def add_segment(self, start_index, end_index):
         segment = PipeLineSegment(start_index, end_index)
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
         segment = self._layer_to_segment[layer_index]
         segment.add_layer(layer_name, param_names)

-    def build_name_mapping(self):
+    def build_name_mapping(self, sname_to_tname=None):
         for (k, segment) in self._segments.items():
             for (i, layer) in segment.layers.items():
                 for param in layer.params.items():
                     (param_name, tensor_name) = param
                     # map to a new name
                     n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
+                    if sname_to_tname is not None:
+                        if param_name in sname_to_tname.keys():
+                            self._wname_to_rname[param_name] = sname_to_tname[param_name]
                     # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                     self._param_to_tname[param_name] = (tensor_name, n_name)

     def map_name(self, param_name, t_name):
         assert param_name in self._param_to_tname
         tensor_name, n_name = self._param_to_tname[param_name]
+        if param_name in self._wname_to_rname:
+            n_name = self._wname_to_rname[param_name]
         assert tensor_name == t_name
         return n_name

@@ -261,6 +285,11 @@ def __init__(
         self._index_layers()

         stage_segments = self._segment()
+        if has_register_sname_to_tname_func():
+            self._sname_to_tname = get_sname_to_tname_func()(pp_model)
+        else:
+            self._sname_to_tname = None
+
         for (i, stage_seg) in enumerate(stage_segments):
             pipe_stage = PipeLineStage()
             self._stages.append(pipe_stage)
@@ -275,7 +304,7 @@
             self._layer_name_to_stage[layer_name] = i

         for stage in self._stages:
-            stage.build_name_mapping()
+            stage.build_name_mapping(self._sname_to_tname)

     def _index_layers(self):
         for layer_name in self._param_names_by_layer.keys():
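
For reference, a minimal sketch of how the new registration hook is presumably used: a caller registers a function that, given the PipelineParallel model, returns a dict from a parameter's structure name (sname) to its target name (tname); the reshard code then threads that mapping through build_name_mapping and map_name as shown above. The mapping function below is purely hypothetical, and the import path is assumed to follow the file path of this module.

from paddlenlp.trainer.utils.reshard.pp_reshard import (
    has_register_sname_to_tname_func,
    register_sname_to_tname_func,
)


def custom_sname_to_tname(pp_model):
    # Hypothetical mapping: keep every parameter name unchanged. A real
    # implementation would return only the names that must be remapped
    # when resharding pipeline-parallel checkpoints.
    return {p.name: p.name for p in pp_model.parameters()}


register_sname_to_tname_func(custom_sname_to_tname)
assert has_register_sname_to_tname_func()
# The reshard adaptor later calls get_sname_to_tname_func()(pp_model) and
# passes the result to stage.build_name_mapping(...), as in the diff above.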
