From 602b7d019f7ebdc3f05344b678fc2522613bfec9 Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Thu, 26 Sep 2024 16:34:42 +0800
Subject: [PATCH 1/3] support pp-sharding reshard (#9153)

---
 paddlenlp/trainer/utils/reshard/pp_reshard.py | 35 +++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/utils/reshard/pp_reshard.py b/paddlenlp/trainer/utils/reshard/pp_reshard.py
index 5c98e6069212..0caa5eb666c6 100644
--- a/paddlenlp/trainer/utils/reshard/pp_reshard.py
+++ b/paddlenlp/trainer/utils/reshard/pp_reshard.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from collections import OrderedDict
 
 from paddle.distributed.fleet.model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
     return _GLOBAL_INDEX_LAYER_FUNC
 
 
+_GLOBAL_SNAME_TO_TNAME_FUNC = None
+
+
+def register_sname_to_tname_func(func):
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    _GLOBAL_SNAME_TO_TNAME_FUNC = func
+
+
+def has_register_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None
+
+
+def get_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
+    return _GLOBAL_SNAME_TO_TNAME_FUNC
+
+
 class LayerNameScope:
     """
     layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
         self._segments = OrderedDict()
         self._layer_to_segment = OrderedDict()
         self._param_to_tname = OrderedDict()
+        self._wname_to_rname = OrderedDict()
 
     def add_segment(self, start_index, end_index):
         segment = PipeLineSegment(start_index, end_index)
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
         segment = self._layer_to_segment[layer_index]
         segment.add_layer(layer_name, param_names)
 
-    def build_name_mapping(self):
+    def build_name_mapping(self, sname_to_tname=None):
         for (k, segment) in self._segments.items():
             for (i, layer) in segment.layers.items():
                 for param in layer.params.items():
                     (param_name, tensor_name) = param
                     # map to a new name
                     n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
+                    if sname_to_tname is not None:
+                        if param_name in sname_to_tname.keys():
+                            self._wname_to_rname[param_name] = sname_to_tname[param_name]
                     # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                     self._param_to_tname[param_name] = (tensor_name, n_name)
 
     def map_name(self, param_name, t_name):
         assert param_name in self._param_to_tname
         tensor_name, n_name = self._param_to_tname[param_name]
+        if param_name in self._wname_to_rname:
+            n_name = self._wname_to_rname[param_name]
         assert tensor_name == t_name
         return n_name
 
@@ -261,6 +285,11 @@ def __init__(
         self._index_layers()
         stage_segments = self._segment()
 
+        if has_register_sname_to_tname_func():
+            self._sname_to_tname = get_sname_to_tname_func()(pp_model)
+        else:
+            self._sname_to_tname = None
+
         for (i, stage_seg) in enumerate(stage_segments):
            pipe_stage = PipeLineStage()
            self._stages.append(pipe_stage)
@@ -275,7 +304,7 @@ def __init__(
                 self._layer_name_to_stage[layer_name] = i
 
         for stage in self._stages:
-            stage.build_name_mapping()
+            stage.build_name_mapping(self._sname_to_tname)
 
     def _index_layers(self):
         for layer_name in self._param_names_by_layer.keys():
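
[Usage sketch for PATCH 1/3; not part of the patch series]
The register/get helpers added above let a model plug its own structure-name
(sname) to tensor-name (tname) mapping into the pp reshard name mapping: the
registered callable is invoked with the PipelineParallel model and must return
a dict whose entries take precedence over the automatically generated names in
map_name(). A minimal sketch follows; the import path comes from the patched
module, while the mapping entries are hypothetical placeholders that a real
model would derive from its own layers.

    from paddlenlp.trainer.utils.reshard.pp_reshard import register_sname_to_tname_func


    def build_sname_to_tname(pp_model):
        # pp_model is the PipelineParallel model handed in by the reshard code.
        # Keys are parameter (structure) names looked up in build_name_mapping();
        # values are the tensor names those parameters should be mapped to.
        return {
            "layers.0.self_attn.qkv_proj.weight": "linear_0.w_0",  # hypothetical entries
            "layers.0.self_attn.out_proj.weight": "linear_1.w_0",
        }


    register_sname_to_tname_func(build_sname_to_tname)

Registering a callable rather than a static dict keeps the mapping lazy: it is
only built once the PipelineParallel model exists, which is why the reshard
code calls get_sname_to_tname_func()(pp_model) inside __init__.
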
From 2f446bbd8da67ba6a77168b6d4a706dcf874cd0f Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Thu, 10 Oct 2024 10:54:00 +0800
Subject: [PATCH 2/3] support best unbalanced pp scheduler (#9235)

---
 paddlenlp/trainer/training_args.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 2f4f6a04a005..ffd90668c161 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1120,6 +1120,7 @@ def split_parallel_config(parallel_config):
                         "enable_clear_every_step_cache",
                         "enable_overlap_p2p_comm",
                         "disable_batch_p2p_comm",
+                        "best_unbalanced_scheduler",
                     ]:
                         raise ValueError(
                             f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -1158,6 +1159,7 @@ def split_parallel_config(parallel_config):
                     "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                     "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
                     "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
+                    "best_unbalanced_scheduler": "best_unbalanced_scheduler" in pipeline_parallel_config,
                 }
                 if dygraph_pp_configs["dp_comm_overlap"]:
                     raise ValueError("overlap has accuracy issue")  # TODO: fix `overalap` + `delay_scale` issue

From 226fdada7f2bcb34ed8d0cb0f23aa1a9d937fe1d Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Tue, 24 Sep 2024 16:12:25 +0800
Subject: [PATCH 3/3] remove pp hack (#9189)

---
 paddlenlp/trainer/trainer.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index ab377acf644a..d09dada4a844 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2270,13 +2270,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         self._pp_data_buffer = []
         model.train()
 
-        # hack pipeline-layers
-        # since the pipeline layer will check input is valid every iter.
-        # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement.
-        config_backup = model.micro_batch_size, model.accumulate_steps
-        model.micro_batch_size = self.args.per_device_train_batch_size
-        model.accumulate_steps = self.args.gradient_accumulation_steps
-
         if model._dp_comm_overlap or model._sharding_comm_overlap:
             for _, buffers in model._chunk_2_comm_buffers.items():
                 for buffer in buffers:
@@ -2291,8 +2284,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         with self.autocast_smart_context_manager():
             loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)
 
-        model.micro_batch_size, model.accumulate_steps = config_backup
-
         return loss.detach()
 
     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
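
[Usage sketch for PATCH 2/3; not part of the patch series]
best_unbalanced_scheduler is switched on through the existing
pipeline_parallel_config field of TrainingArguments and is recorded in
dygraph_pp_configs alongside the other pipeline flags. The values below are
illustrative only; in practice these arguments usually come from the
distributed launch command line rather than being constructed directly.

    from paddlenlp.trainer import TrainingArguments

    training_args = TrainingArguments(
        output_dir="./checkpoints",  # illustrative values
        pipeline_parallel_degree=4,
        pipeline_parallel_config="best_unbalanced_scheduler",
    )
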