Skip to content

Add Sharding V1 broadcast and V2 allgather overlap optimize #8499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 31, 2024
3 changes: 3 additions & 0 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
)
from paddlenlp.utils.log import logger

# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
Expand Down
3 changes: 3 additions & 0 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device

# Pretaining Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "True"


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
Expand Down
16 changes: 15 additions & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
optimizer._set_broadcast_overlap(True, model)

self.optimizer = optimizer

# pure tesnor parallel mode, no pipeline_parallel, no sharding.
if (
not in_pipeline_parallel_mode
Expand All @@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
if (
hasattr(self.optimizer, "_set_all_gather_overlap_forward")
and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_all_gather_overlap_forward(True, model)
else:
if (
hasattr(self.optimizer, "_set_broadcast_overlap")
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_broadcast_overlap(True, model)

return model

def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
Expand Down
39 changes: 37 additions & 2 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ class TrainingArguments:
enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation
enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed
enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap.
enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.
enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
recompute (`bool`, *optional*, defaults to `False`):
Recompute the forward pass to calculate gradients. Used for saving memory.
Expand Down Expand Up @@ -647,7 +649,9 @@ class TrainingArguments:
"enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation\n"
"enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed\n"
"disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.\n"
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap"
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap\n"
"enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.\n"
"enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap."
)
},
)
Expand Down Expand Up @@ -1186,10 +1190,12 @@ def is_segment_parallel_supported():
"enable_stage2_overlap",
"split_param",
"disable_stage1_reduce_avg",
"enable_stage1_broadcast_overlap",
"enable_stage1_allgather_overlap",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, "
f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap."
f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap."
)
if "disable_stage1_reduce_avg" in sharding_parallel_config:
assert self.sharding == [
Expand Down Expand Up @@ -1235,6 +1241,35 @@ def is_segment_parallel_supported():
"The logging_steps should be greater than 1 for stage2 overlap, "
f"but got logging_steps={self.logging_steps}."
)
if "enable_stage1_broadcast_overlap" in sharding_parallel_config:
assert (
ShardingOption.SHARD_OP in self.sharding
), f"enable_stage1_broadcast_overlap expects sharding=stage1, but got {self.sharding}."

assert (
"split_param" not in sharding_parallel_config
), "split_param should not be set when enable_stage1_broadcast_overlap."
use_casual_mask = os.getenv("USE_CASUAL_MASK", "False")
assert use_casual_mask, "enable_stage1_broadcast_overlap requires USE_CASUAL_MASK=True."
assert self.logging_steps > 1, (
"The logging_steps should be greater than 1 for stage1_broadcast_overlap, "
f"but got logging_steps={self.logging_steps}."
)
if "enable_stage1_allgather_overlap" in sharding_parallel_config:
assert (
ShardingOption.SHARD_OP in self.sharding
), f"enable_stage1_allgather_overlap expects sharding=stage1, but got {self.sharding}."

assert (
"split_param" in sharding_parallel_config
), "split_param should be set when enable_stage1_allgather_overlap."
use_casual_mask = os.getenv("USE_CASUAL_MASK", "False")
assert use_casual_mask, "enable_stage1_allgather_overlap requires USE_CASUAL_MASK=True."
assert self.logging_steps > 1, (
"The logging_steps should be greater than 1 for enable_stage1_allgather_overlap, "
f"but got logging_steps={self.logging_steps}."
)

fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

Expand Down
28 changes: 21 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ def _get_interleave_power_of_2(n):
)


def get_use_casual_mask():
"""Get the value of the 'USE_CASUAL_MASK' environment variable."""
return os.getenv("USE_CASUAL_MASK", "False") == "True"


def build_alibi_tensor(
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
) -> Tensor:
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
[num_heads, -1, -1]
Expand Down Expand Up @@ -307,7 +311,7 @@ def is_casual_mask(attention_mask):

def _make_causal_mask(input_ids_shape, past_key_values_length):
"""
Make causal mask used for self-attention
Make casual mask used for self-attention
"""
batch_size, target_length = input_ids_shape # target_length: seq_len

Expand Down Expand Up @@ -1543,12 +1547,22 @@ def forward(
if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
use_casual_mask = get_use_casual_mask()

if use_casual_mask:
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]

is_casual = False

if self.config.use_flash_attention and get_env_device() != "gcu":
is_casual = is_casual_mask(attention_mask)
if use_casual_mask:
is_casual = True
else:
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
if is_casual and alibi is None:
attention_mask = None
Expand Down
Loading