Skip to content

Commit 69be4db

Browse files
lugimzzz and zjjlivein authored
fix (#8668)
Co-authored-by: zhangjunjun04 <zhangjunjun04@baidu.com>
1 parent 80e7ef5 commit 69be4db

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ def forward(
15781578
if position_ids is None:
15791579
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
15801580

1581-
use_casual_mask = get_use_casual_mask()
1581+
use_casual_mask = get_use_casual_mask() and not self.config.alibi
15821582

15831583
if use_casual_mask:
15841584
attention_mask = None

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,26 @@ class LlamaDecoderLayerPipe(LlamaDecoderLayer):
193193
def forward(self, args):
194194
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)
195195
# we can't distinguish
196-
# hidden_states, attention_mask, position_ids or
197-
# hidden_states, attention_mask, alibi
198-
199-
if self.config.alibi and alibi is None and position_ids is not None:
200-
alibi = position_ids
196+
if self.config.alibi and alibi is None and position_ids is None and attn_mask_startend_row_indices is not None:
197+
# hidden_states, attention_mask, alibi
198+
alibi = attn_mask_startend_row_indices
201199
position_ids = None
200+
attn_mask_startend_row_indices = None
201+
elif (
202+
self.config.alibi
203+
and alibi is None
204+
and position_ids is not None
205+
and attn_mask_startend_row_indices is not None
206+
):
207+
# hidden_states, attention_mask, position_ids, alibi
208+
alibi = position_ids
209+
position_ids = attn_mask_startend_row_indices
210+
attn_mask_startend_row_indices = None
211+
elif not self.config.alibi and position_ids is None and attn_mask_startend_row_indices is not None:
212+
# hidden_states, attention_mask, position_ids
213+
position_ids = attn_mask_startend_row_indices
214+
attn_mask_startend_row_indices = None
215+
alibi = None
202216

203217
has_gradient = not hidden_states.stop_gradient
204218
if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:

0 commit comments

Comments
 (0)