@@ -96,9 +96,11 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-global npu_is_casual
+
+
 npu_is_casual = False
 
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -213,7 +215,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-    global npu_is_casual
+
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1119,7 +1121,6 @@ def __init__(self, config, layerwise_recompute: bool = False):
         self.layerwise_recompute = layerwise_recompute
         self.recompute_granularity = config.recompute_granularity
 
-
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -1613,14 +1614,12 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
-        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
             else:
-                npu_is_casual = is_casual
                 attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
         # decoder layers
@@ -1728,10 +1727,12 @@ def forward(self, prediction_scores, masked_lm_labels):
             # skip ignore_index which loss == 0
             # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
             # loss = paddle.mean(masked_lm_loss)
-            binary_sequence = paddle.where(masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss))
+            binary_sequence = paddle.where(
+                masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
+            )
             sum_ = paddle.sum(binary_sequence)
-            loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
-
+            loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
+
         return loss
 
 
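The removals above drop every read and write of the module-level npu_is_casual global; the flag is presumably threaded through the call chain instead (the corresponding additions fall outside these hunks). Below is a minimal, hypothetical Python sketch of that pattern only; the stub names and signatures are illustrative assumptions, not the PaddleNLP API.

# Hypothetical sketch: pass the causal flag down as an argument
# instead of mutating a module-level `global npu_is_casual` variable.

def attention_stub(query, key, value, attention_mask, is_casual=False):
    # `is_casual` arrives as an explicit parameter, so no global lookup is
    # needed and concurrent callers cannot clobber each other's flag.
    return {"q_len": len(query), "kv_len": len(key), "is_casual": is_casual}


def decoder_forward_stub(hidden_states, attention_mask, device="npu"):
    # Decide causality locally, once per call...
    is_casual = attention_mask is None
    if device == "npu":
        # ...and hand it to the attention kernel rather than writing a global.
        return attention_stub(hidden_states, hidden_states, hidden_states,
                              attention_mask, is_casual=is_casual)
    return attention_stub(hidden_states, hidden_states, hidden_states, attention_mask)


if __name__ == "__main__":
    print(decoder_forward_stub([0.0] * 4, attention_mask=None, device="npu"))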