@@ -212,6 +212,7 @@ def scaled_dot_product_attention(
     attention_mask,
     output_attentions,
     alibi=None,
+    attn_mask_start_row_indices=None,
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
@@ -228,6 +229,7 @@ def scaled_dot_product_attention(
             attention_mask,
             output_attentions,
             alibi,
+            attn_mask_start_row_indices,
             sequence_parallel,
             reshard_layer,
             npu_is_casual,
@@ -815,6 +817,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
         npu_is_casual: bool = False,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -1013,6 +1016,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -1026,6 +1030,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
                 npu_is_casual=npu_is_casual,
@@ -1081,6 +1086,7 @@ def forward(
         past_key_value: Optional[Tuple[paddle.Tensor]] = None,
         use_cache: Optional[bool] = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
         npu_is_casual: bool = False,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         """
@@ -1118,6 +1124,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
@@ -1129,6 +1136,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices=attn_mask_start_row_indices,
                 npu_is_casual=npu_is_casual,
             )

@@ -1458,6 +1466,7 @@ def recompute_training_full(
         past_key_value: Tensor,
         use_cache: bool,
         alibi=None,
+        attn_mask_start_row_indices=None,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -1474,6 +1483,7 @@ def custom_forward(*inputs):
             past_key_value,
             use_cache,
             alibi,
+            attn_mask_start_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
         )

@@ -1490,6 +1500,7 @@ def forward(
         output_attentions=False,
         output_hidden_states=None,
         return_dict=False,
+        attn_mask_start_row_indices=None,
         **kwargs,
     ):
         if self.sequence_parallel and use_cache:
@@ -1536,10 +1547,10 @@ def forward(
         if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
             raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi")
         # embed positions
-        if attention_mask is None:
+        if attn_mask_start_row_indices is None and attention_mask is None:
             # [bs, seq_len]
             attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
-        if self.config.alibi:
+        if attn_mask_start_row_indices is None and self.config.alibi:
             if self.config.use_long_sequence_strategies:
                 alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
                     self.config.long_sequence_strategy_type,
@@ -1570,14 +1581,14 @@ def forward(

         if use_casual_mask:
             attention_mask = None
-        else:
+        elif attn_mask_start_row_indices is None:
             attention_mask = self._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]

         is_casual = False

-        if self.config.use_flash_attention and get_env_device() != "gcu":
+        if attn_mask_start_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu":
             if use_casual_mask:
                 is_casual = True
             else:
@@ -1614,6 +1625,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1624,6 +1636,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                     npu_is_casual=is_casual,
                 )

@@ -1881,6 +1894,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        attn_mask_start_row_indices=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1897,6 +1911,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_start_row_indices=attn_mask_start_row_indices,
         )

         hidden_states = outputs[0]  # [bs, seq_len, dim]
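Note: the hunks above only thread the new `attn_mask_start_row_indices` argument through the call chain, from `LlamaForCausalLM.forward` down to `scaled_dot_product_attention`; they do not show how the attention kernel consumes it. The sketch below is a hypothetical illustration of how such a tensor could be built for packed sequences, assuming a FlashMask-style encoding in which each key position stores the first query row where it becomes masked, with shape `[batch_size, num_heads, seq_len]`. The shape and semantics are assumptions, not taken from this diff.

```python
import paddle

# Hypothetical sketch: two sequences of lengths 3 and 5 packed into one row.
# Assumption: attn_mask_start_row_indices[b, h, k] is the first query row at
# which key position k is masked out, so each packed segment is only visible
# to queries inside the same segment.
seq_lens = [3, 5]

start_rows = []
offset = 0
for length in seq_lens:
    # masking for this segment's keys starts at the row right after it ends
    start_rows.extend([offset + length] * length)
    offset += length

attn_mask_start_row_indices = paddle.to_tensor(start_rows, dtype="int32").reshape(
    [1, 1, len(start_rows)]  # assumed [batch_size, num_heads(=1, broadcast), seq_len]
)

# The tensor would then be passed in the same way the diff threads it, e.g.
# model(input_ids, attn_mask_start_row_indices=attn_mask_start_row_indices),
# in place of a dense [bs, 1, seq_len, seq_len] attention_mask.
```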