@@ -193,12 +193,26 @@ class LlamaDecoderLayerPipe(LlamaDecoderLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)
         # we can't distinguish
-        # hidden_states, attention_mask, position_ids or
-        # hidden_states, attention_mask, alibi
-
-        if self.config.alibi and alibi is None and position_ids is not None:
-            alibi = position_ids
+        if self.config.alibi and alibi is None and position_ids is None and attn_mask_startend_row_indices is not None:
+            # hidden_states, attention_mask, alibi
+            alibi = attn_mask_startend_row_indices
             position_ids = None
+            attn_mask_startend_row_indices = None
+        elif (
+            self.config.alibi
+            and alibi is None
+            and position_ids is not None
+            and attn_mask_startend_row_indices is not None
+        ):
+            # hidden_states, attention_mask, position_ids, alibi
+            alibi = position_ids
+            position_ids = attn_mask_startend_row_indices
+            attn_mask_startend_row_indices = None
+        elif not self.config.alibi and position_ids is None and attn_mask_startend_row_indices is not None:
+            # hidden_states, attention_mask, position_ids
+            position_ids = attn_mask_startend_row_indices
+            attn_mask_startend_row_indices = None
+            alibi = None
 
         has_gradient = not hidden_states.stop_gradient
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
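The branching above exists because pipeline stages exchange a flat tuple, so a 3- or 4-element input is positionally ambiguous: the same slots can carry `position_ids` or `alibi` depending on the caller. Below is a minimal, self-contained sketch of that disambiguation, not the repository's code: it assumes `parse_args` fills the five slots in order and pads with `None`, and the names `disambiguate`, `use_alibi`, and `row_indices` are hypothetical stand-ins.

```python
# Sketch only, not the repository's implementation. Assumes parse_args pads
# the five positional slots with None; helper names are hypothetical.

def parse_args(args):
    """Unpack a pipeline-stage tuple into five slots, padding missing ones with None."""
    args = list(args) if isinstance(args, (tuple, list)) else [args]
    args += [None] * (5 - len(args))
    # Slot order mirrors the diff above:
    # hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi
    return tuple(args[:5])


def disambiguate(args, use_alibi):
    hidden_states, attention_mask, row_indices, position_ids, alibi = parse_args(args)
    if use_alibi and alibi is None and position_ids is None and row_indices is not None:
        # 3-tuple: (hidden_states, attention_mask, alibi)
        alibi, row_indices = row_indices, None
    elif use_alibi and alibi is None and position_ids is not None and row_indices is not None:
        # 4-tuple: (hidden_states, attention_mask, position_ids, alibi)
        alibi, position_ids, row_indices = position_ids, row_indices, None
    elif not use_alibi and position_ids is None and row_indices is not None:
        # 3-tuple: (hidden_states, attention_mask, position_ids)
        position_ids, row_indices, alibi = row_indices, None, None
    return hidden_states, attention_mask, row_indices, position_ids, alibi


# With alibi enabled, the third element is reinterpreted as the alibi tensor:
print(disambiguate(("h", "mask", "alibi_tensor"), use_alibi=True))
# -> ('h', 'mask', None, None, 'alibi_tensor')
```

Unlike the removed heuristic, which treated any non-`None` `position_ids` as `alibi`, the explicit `elif` chain keys off both `config.alibi` and which slots are occupied, so each of the three ambiguous tuple shapes is resolved deterministically.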