
Commit 71e2b64

fix some typo
1 parent 10971e8 commit 71e2b64

File tree

1 file changed: +4, -5 lines changed


paddlenlp/transformers/qwen/modeling.py

Lines changed: 4 additions & 5 deletions
@@ -198,7 +198,7 @@ def __init__(self, config):
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
 
-    def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None):
+    def _attn(self, query, key, value, attention_mask=None):
         # Support the flash attention and normal attention
         bsz, q_len, num_heads, head_dim = query.shape
         _, kv_seq_len, _, _ = value.shape
@@ -228,7 +228,7 @@ def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None):
             )
             attn_weights = None
 
-            if sequence_parallel:
+            if self.sequence_parallel:
                 attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
             else:
                 attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
@@ -258,7 +258,7 @@ def _attn(self, query, key, value, sequence_parallel=False, attention_mask=None):
             attn_output = paddle.matmul(attn_weights, value)
             attn_output = attn_output.transpose([0, 2, 1, 3])
 
-            if sequence_parallel:
+            if self.sequence_parallel:
                 attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
             else:
                 attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
@@ -356,12 +356,11 @@ def forward(
                 query,
                 key,
                 value,
-                self.sequence_parallel,
                 attention_mask,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
-            attn_output, attn_weight = self._attn(query, key, value, self.sequence_parallel, attention_mask)
+            attn_output, attn_weight = self._attn(query, key, value, attention_mask)
 
         # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
         # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
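
For readers skimming the diff: the fix moves the sequence-parallel switch off the _attn argument list and reads it from the module as self.sequence_parallel, so the call sites stay self._attn(query, key, value, attention_mask). Below is a minimal sketch of that pattern, not part of the commit; TinyAttention, the pass-through attention body, and the shapes are all made up for illustration.

import paddle
import paddle.nn as nn


class TinyAttention(nn.Layer):
    # Sketch only: mirrors the post-fix pattern where the sequence-parallel
    # flag is read from the module rather than passed positionally.
    def __init__(self, sequence_parallel=False):
        super().__init__()
        self.sequence_parallel = sequence_parallel  # in the real model this comes from config

    def _attn(self, query, key, value, attention_mask=None):
        bsz, q_len, num_heads, head_dim = query.shape
        attn_output = value  # stand-in for the real attention computation

        if self.sequence_parallel:
            # fold batch and sequence together, as in the reshape in the diff above
            return attn_output.reshape([bsz * q_len, head_dim * num_heads])
        return attn_output.reshape([bsz, q_len, head_dim * num_heads])


# illustrative shapes: [bsz, q_len, num_heads, head_dim] = [2, 8, 4, 16]
q = k = v = paddle.randn([2, 8, 4, 16])
print(TinyAttention(sequence_parallel=True)._attn(q, k, v).shape)   # [16, 64]
print(TinyAttention(sequence_parallel=False)._attn(q, k, v).shape)  # [2, 8, 64]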

0 commit comments
