@@ -198,7 +198,7 @@ def __init__(self, config):
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
 
-    def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None):
+    def _attn(self, query, key, value, attention_mask=None):
         # Support the flash attention and normal attention
         bsz, q_len, num_heads, head_dim = query.shape
         _, kv_seq_len, _, _ = value.shape
@@ -228,7 +228,7 @@ def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None)
             )
             attn_weights = None
 
-            if sequence_parallel:
+            if self.sequence_parallel:
                 attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
             else:
                 attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
@@ -258,7 +258,7 @@ def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None)
             attn_output = paddle.matmul(attn_weights, value)
             attn_output = attn_output.transpose([0, 2, 1, 3])
 
-            if sequence_parallel:
+            if self.sequence_parallel:
                 attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
             else:
                 attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
@@ -356,12 +356,11 @@ def forward(
                 query,
                 key,
                 value,
-                self.sequence_parallel,
                 attention_mask,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
-            attn_output, attn_weight = self._attn(query, key, value, self.sequence_parallel, attention_mask)
+            attn_output, attn_weight = self._attn(query, key, value, attention_mask)
 
         # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
         # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
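The two comments above describe the output layouts that the reshapes in this patch produce. The snippet below is an illustrative sketch only: the dummy tensor, the sizes, and the standalone sequence_parallel flag are assumptions for demonstration and are not part of this patch, where the flag is now read from self.sequence_parallel instead of being passed into _attn.

import paddle

# Illustrative sketch: the two reshape paths the diff switches between.
# All names and sizes here are made up for demonstration purposes.
bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
attn_output = paddle.rand([bsz, q_len, num_heads, head_dim])

sequence_parallel = True  # in the patched code this comes from self.sequence_parallel

if sequence_parallel:
    # fuse batch and sequence dims: [bsz * q_len, num_heads * head_dim]
    out = attn_output.reshape([bsz * q_len, head_dim * num_heads])
else:
    # keep batch and sequence dims separate: [bsz, q_len, num_heads * head_dim]
    out = attn_output.reshape([bsz, q_len, head_dim * num_heads])

print(out.shape)  # [16, 64] when sequence_parallel, [2, 8, 64] otherwise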