
Commit 0f45875

ZHUI and lugimzzz authored
Revert "update" (#8389)
* Revert "update (#8359)" This reverts commit ae0bea9. * Update modeling.py * fix --------- Co-authored-by: lugimzzz <zhenglujing@baidu.com>
1 parent: 829e7f0

File tree

1 file changed (+9, -8 lines)


paddlenlp/transformers/llama/modeling.py

Lines changed: 9 additions & 8 deletions
@@ -96,9 +96,11 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-global npu_is_casual
+
+
 npu_is_casual = False
 
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -213,7 +215,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-    global npu_is_casual
+
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1119,7 +1121,6 @@ def __init__(self, config, layerwise_recompute: bool = False):
         self.layerwise_recompute = layerwise_recompute
         self.recompute_granularity = config.recompute_granularity
 
-
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -1613,14 +1614,12 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
-        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
             else:
-                npu_is_casual = is_casual
                 attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
         # decoder layers
@@ -1728,10 +1727,12 @@ def forward(self, prediction_scores, masked_lm_labels):
             # skip ignore_index which loss == 0
             # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
             # loss = paddle.mean(masked_lm_loss)
-            binary_sequence = paddle.where(masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss))
+            binary_sequence = paddle.where(
+                masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
+            )
             sum_ = paddle.sum(binary_sequence)
-            loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
-
+            loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
+
         return loss
 
 
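
Note on the last hunk: the change in LlamaPretrainingCriterion.forward only re-wraps the existing paddle.where call; the ignore-index-aware averaging itself is unchanged. A minimal sketch of that pattern follows (the tensor values are illustrative, not taken from the commit):

import paddle

# Per-token loss; positions with loss == 0 correspond to ignored labels (ignore_index).
masked_lm_loss = paddle.to_tensor([0.0, 1.2, 0.0, 0.8])  # illustrative values

# 1.0 where a token contributed a loss, 0.0 where it was ignored.
binary_sequence = paddle.where(
    masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
)
sum_ = paddle.sum(binary_sequence)

# Average only over contributing tokens; fall back to 0 if every token was ignored.
loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
print(float(loss))  # 1.0 for the illustrative values above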

0 commit comments
