Skip to content

Commit e71540b

Browse files
iosmersdeepllzheavyrain-lzy
authored
Add Sharding V1 broadcast and allgather overlap optimize (#8499)
* update * update is_casual_mask to use_casual_mask * update by environment * add constraint * add pretrain and finetune environment * update * update * Update finetune_generation.py update use_casual_mask env * update * lint code --------- Co-authored-by: zhengzhonghui <zhengzhonghui@baidu.com> Co-authored-by: lizhiyu <1528794076@qq.com>
1 parent ac7145d commit e71540b

File tree

5 files changed

+79
-10
lines changed

5 files changed

+79
-10
lines changed

llm/finetune_generation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
)
5252
from paddlenlp.utils.log import logger
5353

54+
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
55+
os.environ["USE_CASUAL_MASK"] = "False"
56+
5457

5558
def add_start_docstrings(*docstr):
5659
def docstring_decorator(fn):

llm/run_pretrain.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
from paddlenlp.utils.log import logger
4949
from paddlenlp.utils.tools import get_env_device
5050

51+
# Pretraining environment variables to support sharding stage1 overlap optimization.
52+
os.environ["USE_CASUAL_MASK"] = "True"
53+
5154

5255
def add_start_docstrings(*docstr):
5356
def docstring_decorator(fn):

paddlenlp/trainer/trainer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,6 @@ def get_expected_keys(inputs, keys):
18971897
optimizer._set_broadcast_overlap(True, model)
18981898

18991899
self.optimizer = optimizer
1900-
19011900
# pure tensor parallel mode, no pipeline_parallel, no sharding.
19021901
if (
19031902
not in_pipeline_parallel_mode
@@ -1913,6 +1912,21 @@ def get_expected_keys(inputs, keys):
19131912
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
19141913
self.optimizer = fleet.distributed_optimizer(self.optimizer)
19151914

1915+
# stage1 has v1 and v2 version
1916+
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
1917+
if "split_param" in self.args.sharding_parallel_config:
1918+
if (
1919+
hasattr(self.optimizer, "_set_all_gather_overlap_forward")
1920+
and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
1921+
):
1922+
self.optimizer._set_all_gather_overlap_forward(True, model)
1923+
else:
1924+
if (
1925+
hasattr(self.optimizer, "_set_broadcast_overlap")
1926+
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
1927+
):
1928+
self.optimizer._set_broadcast_overlap(True, model)
1929+
19161930
return model
19171931

19181932
def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:

paddlenlp/trainer/training_args.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ class TrainingArguments:
263263
enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation
264264
enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed
265265
enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap.
266+
enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.
267+
enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
266268
disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
267269
recompute (`bool`, *optional*, defaults to `False`):
268270
Recompute the forward pass to calculate gradients. Used for saving memory.
@@ -647,7 +649,9 @@ class TrainingArguments:
647649
"enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation\n"
648650
"enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed\n"
649651
"disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.\n"
650-
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap"
652+
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap\n"
653+
"enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.\n"
654+
"enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap."
651655
)
652656
},
653657
)
@@ -1192,10 +1196,12 @@ def is_segment_parallel_supported():
11921196
"enable_stage2_overlap",
11931197
"split_param",
11941198
"disable_stage1_reduce_avg",
1199+
"enable_stage1_broadcast_overlap",
1200+
"enable_stage1_allgather_overlap",
11951201
]:
11961202
raise ValueError(
11971203
f"Found unknown pipeline mode config {x}, "
1198-
f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap."
1204+
f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap."
11991205
)
12001206
if "disable_stage1_reduce_avg" in sharding_parallel_config:
12011207
assert self.sharding == [
@@ -1241,6 +1247,35 @@ def is_segment_parallel_supported():
12411247
"The logging_steps should be greater than 1 for stage2 overlap, "
12421248
f"but got logging_steps={self.logging_steps}."
12431249
)
1250+
if "enable_stage1_broadcast_overlap" in sharding_parallel_config:
1251+
assert (
1252+
ShardingOption.SHARD_OP in self.sharding
1253+
), f"enable_stage1_broadcast_overlap expects sharding=stage1, but got {self.sharding}."
1254+
1255+
assert (
1256+
"split_param" not in sharding_parallel_config
1257+
), "split_param should not be set when enable_stage1_broadcast_overlap."
1258+
use_casual_mask = os.getenv("USE_CASUAL_MASK", "False")
1259+
assert use_casual_mask, "enable_stage1_broadcast_overlap requires USE_CASUAL_MASK=True."
1260+
assert self.logging_steps > 1, (
1261+
"The logging_steps should be greater than 1 for stage1_broadcast_overlap, "
1262+
f"but got logging_steps={self.logging_steps}."
1263+
)
1264+
if "enable_stage1_allgather_overlap" in sharding_parallel_config:
1265+
assert (
1266+
ShardingOption.SHARD_OP in self.sharding
1267+
), f"enable_stage1_allgather_overlap expects sharding=stage1, but got {self.sharding}."
1268+
1269+
assert (
1270+
"split_param" in sharding_parallel_config
1271+
), "split_param should be set when enable_stage1_allgather_overlap."
1272+
use_casual_mask = os.getenv("USE_CASUAL_MASK", "False")
1273+
assert use_casual_mask, "enable_stage1_allgather_overlap requires USE_CASUAL_MASK=True."
1274+
assert self.logging_steps > 1, (
1275+
"The logging_steps should be greater than 1 for enable_stage1_allgather_overlap, "
1276+
f"but got logging_steps={self.logging_steps}."
1277+
)
1278+
12441279
fleet.init(is_collective=True, strategy=strategy)
12451280
logger.info(strategy)
12461281

paddlenlp/transformers/llama/modeling.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,15 @@ def _get_interleave_power_of_2(n):
115115
)
116116

117117

118+
def get_use_casual_mask():
    """Return whether the 'USE_CASUAL_MASK' environment variable is enabled.

    The variable is read on every call, so changes made to ``os.environ``
    after import are picked up.

    Returns:
        bool: ``True`` when the variable is set to ``"true"`` (any letter
        case) or ``"1"``, ignoring surrounding whitespace; ``False`` for any
        other value or when the variable is unset.
    """
    # Normalize before comparing: an exact match against "True" would treat
    # common spellings such as "true" or "1" as disabled without any warning.
    raw_value = os.getenv("USE_CASUAL_MASK", "False")
    return raw_value.strip().lower() in ("true", "1")
121+
122+
118123
def build_alibi_tensor(
119124
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
120125
) -> Tensor:
121-
attention_mask = bool_attention_mask.astype("float32")
122-
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
126+
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
123127
slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
124128
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
125129
[num_heads, -1, -1]
@@ -307,7 +311,7 @@ def is_casual_mask(attention_mask):
307311

308312
def _make_causal_mask(input_ids_shape, past_key_values_length):
309313
"""
310-
Make causal mask used for self-attention
314+
Make casual mask used for self-attention
311315
"""
312316
batch_size, target_length = input_ids_shape # target_length: seq_len
313317

@@ -1543,12 +1547,22 @@ def forward(
15431547
if position_ids is None:
15441548
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
15451549

1546-
attention_mask = self._prepare_decoder_attention_mask(
1547-
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
1548-
) # [bs, 1, seq_len, seq_len]
1550+
use_casual_mask = get_use_casual_mask()
1551+
1552+
if use_casual_mask:
1553+
attention_mask = None
1554+
else:
1555+
attention_mask = self._prepare_decoder_attention_mask(
1556+
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
1557+
) # [bs, 1, seq_len, seq_len]
1558+
15491559
is_casual = False
1560+
15501561
if self.config.use_flash_attention and get_env_device() != "gcu":
1551-
is_casual = is_casual_mask(attention_mask)
1562+
if use_casual_mask:
1563+
is_casual = True
1564+
else:
1565+
is_casual = is_casual_mask(attention_mask)
15521566
if get_env_device() != "npu":
15531567
if is_casual and alibi is None:
15541568
attention_mask = None

0 commit comments

Comments
 (0)