
Commit cf69b0b

[LLM] support sparse attention for LLAMA
1 parent 5bdf751 commit cf69b0b
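
This commit threads a new attn_mask_start_row_indices argument from LlamaForCausalLM down to the fused attention kernels so that LLaMA can run sparse (block-diagonal causal) attention via paddle.nn.functional.flash_attention_with_sparse_mask instead of materializing a dense attention mask. As a rough sketch of what such indices look like for packed sequences, the helper below builds them under the assumed convention that entry [b, h, j] holds the first query row that may no longer attend to key column j; the helper name and the [batch, num_heads, seq_len] layout are illustrative assumptions, not part of this commit.

import paddle

def build_start_row_indices(doc_lens, num_heads):
    # Hypothetical helper: with is_causal=True, key column j stays visible only to
    # query rows i with j <= i < indices[b, h, j], i.e. block-diagonal causal
    # attention over documents packed into one sequence.
    seq_len = sum(doc_lens)
    indices = paddle.zeros([1, num_heads, seq_len], dtype="int32")
    offset = 0
    for length in doc_lens:
        end = offset + length
        indices[:, :, offset:end] = end  # rows >= end cannot see these keys
        offset = end
    return indices

# Two documents of lengths 3 and 5 packed into a sequence of length 8.
start_rows = build_start_row_indices([3, 5], num_heads=32)
print(start_rows[0, 0].tolist())  # [3, 3, 3, 8, 8, 8, 8, 8]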

File tree

3 files changed: +95 -16 lines

  paddlenlp/transformers/llama/fusion_ops.py
  paddlenlp/transformers/llama/modeling.py
  paddlenlp/transformers/llama/modeling_pp.py

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 18 additions & 7 deletions

@@ -149,6 +149,7 @@ def fusion_flash_attention(
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    attn_mask_start_row_indices=None,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -208,13 +209,23 @@ def fusion_flash_attention(
             is_causal=True,
         )
     else:
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=attention_mask is None,
-        )
+        if attn_mask_start_row_indices is not None:
+            assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
+            attn_output = F.flash_attention_with_sparse_mask(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask_start_row_indices=attn_mask_start_row_indices,
+                is_causal=True,
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=attention_mask is None,
+            )
     attn_weights = None

     if reshard_layer is not None:
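
With this change, fusion_flash_attention routes to F.flash_attention_with_sparse_mask whenever attn_mask_start_row_indices is given (always causal, and incompatible with alibi per the assert), and otherwise falls back to the dense scaled_dot_product_attention path unchanged. For intuition, here is a small illustrative sketch (not from the commit) that materializes the dense additive mask implied by such indices, i.e. what the fallback path would need in order to express the same sparsity.

import paddle

def dense_mask_from_start_rows(start_rows):
    # start_rows: int32 [batch, num_heads, seq_len]; returns an additive float mask
    # of shape [batch, num_heads, seq_len, seq_len].
    seq_len = start_rows.shape[-1]
    rows = paddle.arange(seq_len).reshape([1, 1, seq_len, 1])   # query index i
    cols = paddle.arange(seq_len).reshape([1, 1, 1, seq_len])   # key index j
    causal = cols <= rows                                       # j <= i
    in_window = rows < start_rows.astype("int64").unsqueeze(2)  # i < start_rows[..., j]
    allowed = paddle.logical_and(causal, in_window)
    # 0 where attention is allowed, a large negative value where it is masked.
    return (1.0 - allowed.astype("float32")) * -1e4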

paddlenlp/transformers/llama/modeling.py

Lines changed: 19 additions & 4 deletions

@@ -212,6 +212,7 @@ def scaled_dot_product_attention(
     attention_mask,
     output_attentions,
     alibi=None,
+    attn_mask_start_row_indices=None,
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
@@ -228,6 +229,7 @@ def scaled_dot_product_attention(
             attention_mask,
             output_attentions,
             alibi,
+            attn_mask_start_row_indices,
             sequence_parallel,
             reshard_layer,
             npu_is_casual,
@@ -815,6 +817,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
         npu_is_casual: bool = False,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -1013,6 +1016,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -1026,6 +1030,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
                 npu_is_casual=npu_is_casual,
@@ -1081,6 +1086,7 @@ def forward(
         past_key_value: Optional[Tuple[paddle.Tensor]] = None,
         use_cache: Optional[bool] = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
         npu_is_casual: bool = False,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         """
@@ -1118,6 +1124,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
@@ -1129,6 +1136,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices=attn_mask_start_row_indices,
                 npu_is_casual=npu_is_casual,
             )

@@ -1458,6 +1466,7 @@ def recompute_training_full(
         past_key_value: Tensor,
         use_cache: bool,
         alibi=None,
+        attn_mask_start_row_indices=None,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -1474,6 +1483,7 @@ def custom_forward(*inputs):
             past_key_value,
             use_cache,
             alibi,
+            attn_mask_start_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
         )

@@ -1490,6 +1500,7 @@ def forward(
         output_attentions=False,
         output_hidden_states=None,
         return_dict=False,
+        attn_mask_start_row_indices=None,
         **kwargs,
     ):
         if self.sequence_parallel and use_cache:
@@ -1536,10 +1547,10 @@ def forward(
         if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
             raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi")
         # embed positions
-        if attention_mask is None:
+        if attn_mask_start_row_indices is None and attention_mask is None:
             # [bs, seq_len]
             attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
-        if self.config.alibi:
+        if attn_mask_start_row_indices is None and self.config.alibi:
             if self.config.use_long_sequence_strategies:
                 alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
                     self.config.long_sequence_strategy_type,
@@ -1570,14 +1581,14 @@ def forward(

         if use_casual_mask:
             attention_mask = None
-        else:
+        elif attn_mask_start_row_indices is None:
             attention_mask = self._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]

         is_casual = False

-        if self.config.use_flash_attention and get_env_device() != "gcu":
+        if attn_mask_start_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu":
             if use_casual_mask:
                 is_casual = True
             else:
@@ -1614,6 +1625,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1624,6 +1636,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                     npu_is_casual=is_casual,
                 )

@@ -1881,6 +1894,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        attn_mask_start_row_indices=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1897,6 +1911,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_start_row_indices=attn_mask_start_row_indices,
         )

         hidden_states = outputs[0]  # [bs, seq_len, dim]
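
End to end, attn_mask_start_row_indices is now accepted by LlamaForCausalLM.forward, forwarded through LlamaModel and the decoder layers (including the recompute paths), and consumed in scaled_dot_product_attention; when it is set, LlamaModel also skips building the dense decoder attention mask. A hedged usage sketch follows: the tiny config values and the [batch, seq_len] index shape are assumptions for illustration, and the sparse path additionally needs a GPU build of Paddle with FlashAttention support.

import paddle
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

# Deliberately tiny configuration, for illustration only.
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    use_flash_attention=True,  # the sparse-mask branch sits behind flash attention
)
model = LlamaForCausalLM(config)

input_ids = paddle.randint(0, config.vocab_size, [1, 8])
# Two packed documents of lengths 3 and 5; dtype/shape assumed for illustration.
start_rows = paddle.to_tensor([[3, 3, 3, 8, 8, 8, 8, 8]], dtype="int32")

# Pass the sparse indices instead of a dense attention_mask.
outputs = model(input_ids=input_ids, attn_mask_start_row_indices=start_rows, labels=input_ids)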

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 58 additions & 5 deletions

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import OrderedDict
+
 import paddle
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
@@ -149,7 +151,7 @@ def forward(self, args):
             alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length])
             alibi.stop_gradient = True

-        if attention_mask is not None:
+        if attention_mask is not None and attention_mask.dtype != paddle.int32:
             attention_mask = LlamaModel._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
             )
@@ -175,23 +177,44 @@ def forward(self, args):
         # we can't distinguish
         # hidden_states, attention_mask, position_ids or
         # hidden_states, attention_mask, alibi
+
+        if attention_mask is None:
+            attn_mask_start_row_indices = None
+        elif attention_mask.dtype == paddle.int32:
+            attn_mask_start_row_indices = attention_mask
+        else:
+            attn_mask_start_row_indices = None
+
         if self.config.alibi and alibi is None and position_ids is not None:
             alibi = position_ids
             position_ids = None

         has_gradient = not hidden_states.stop_gradient
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
-            if attention_mask is not None or alibi is not None:
+            if attention_mask is not None or alibi is not None or attn_mask_start_row_indices is not None:
                 hidden_states = recompute(
-                    super().forward, hidden_states, attention_mask=attention_mask, alibi=alibi, use_reentrant=False
+                    super().forward,
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
+                    use_reentrant=False,
                 )
             else:
                 # for pretrain
                 hidden_states = recompute(
-                    super().forward, hidden_states, use_reentrant=self.config.recompute_use_reentrant
+                    super().forward,
+                    hidden_states,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
+                    use_reentrant=self.config.recompute_use_reentrant,
                 )
         else:
-            hidden_states = super().forward(hidden_states, attention_mask=attention_mask, alibi=alibi)
+            hidden_states = super().forward(
+                hidden_states,
+                attention_mask=attention_mask,
+                alibi=alibi,
+                attn_mask_start_row_indices=attn_mask_start_row_indices,
+            )

         return return_args(hidden_states, attention_mask, position_ids, alibi)

@@ -222,6 +245,36 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):

     # DONOT Add base_model_prefix !!!!

+    @classmethod
+    def _prepare_pipeline_inputs_func(cls, inputs):
+        first_stage_keys = ["input_ids", "attn_mask_start_row_indices", "position_ids"]
+        if type(inputs) is dict or type(inputs) is OrderedDict:
+            if "attention_mask" in inputs:
+                first_stage_keys = ["input_ids", "attention_mask", "position_ids"]
+        else:  # inputs is list
+            if "attention_mask" in inputs[0]:
+                first_stage_keys = ["input_ids", "attention_mask", "position_ids"]
+        last_stage_keys = ["labels"]
+
+        def get_expected_keys(inputs, keys):
+            ret = tuple([inputs.pop(k) for k in keys if k in inputs])
+            if len(ret) == 1:
+                ret = ret[0]
+            return ret
+
+        if type(inputs) is dict or type(inputs) is OrderedDict:
+            return [
+                get_expected_keys(inputs, first_stage_keys),
+                get_expected_keys(inputs, last_stage_keys),
+            ]
+
+        keys = list(inputs[0].keys())
+        inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
+        return [
+            get_expected_keys(inputs_batch, first_stage_keys),
+            get_expected_keys(inputs_batch, last_stage_keys),
+        ]
+
     def __init__(self, config):
         self.config = config
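
In the pipeline-parallel variant, the stage inputs have no dedicated slot for the indices, so the dense attention-mask slot is reused: an int32 tensor arriving as attention_mask is reinterpreted as attn_mask_start_row_indices (hence the dtype checks above), and the new _prepare_pipeline_inputs_func decides which keys feed the first stage. A hedged illustration of that routing, with shapes and values assumed:

from collections import OrderedDict

import paddle

from paddlenlp.transformers.llama.modeling_pp import LlamaForCausalLMPipe

# One micro-batch without a dense attention_mask, so the sparse indices are
# kept among the first-stage inputs.
batch = [
    OrderedDict(
        input_ids=paddle.randint(0, 32000, [1, 8]),
        attn_mask_start_row_indices=paddle.to_tensor([[3, 3, 3, 8, 8, 8, 8, 8]], dtype="int32"),
        labels=paddle.randint(0, 32000, [1, 8]),
    )
]

first_stage, last_stage = LlamaForCausalLMPipe._prepare_pipeline_inputs_func(batch)
# first_stage == ([input_ids], [attn_mask_start_row_indices])  -- per-micro-batch lists
# last_stage  == [labels]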
