
Commit 7a53d1b

[LLM] fix bug when loss is None in llama modeling.py
1 parent daf2f3a commit 7a53d1b

File tree: 1 file changed, +5 −2 lines changed


paddlenlp/transformers/llama/modeling.py

Lines changed: 5 additions & 2 deletions
@@ -1645,8 +1645,11 @@ def forward(self, prediction_scores, masked_lm_labels):
         binary_sequence = paddle.where(
             masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
         )
-        sum_ = paddle.sum(binary_sequence)
-        loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
+        count = paddle.sum(binary_sequence)
+        if count == 0:
+            loss = paddle.sum(masked_lm_loss * binary_sequence)
+        else:
+            loss = paddle.sum(masked_lm_loss * binary_sequence) / count
 
         return loss
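For reference, a minimal standalone sketch of the patched logic (the helper name masked_mean_loss and the toy input are illustrative, not part of the repository): when no token contributes to the loss, the new branch returns the masked sum, which is a zero-valued paddle Tensor, instead of the Python scalar 0 produced by the old one-liner.

import paddle

def masked_mean_loss(masked_lm_loss):
    # Positions with a positive per-token loss are the ones that count.
    binary_sequence = paddle.where(
        masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
    )
    count = paddle.sum(binary_sequence)
    if count == 0:
        # The masked sum is a zero Tensor, so callers always receive a Tensor
        # rather than the plain int 0 returned by the previous ternary expression.
        loss = paddle.sum(masked_lm_loss * binary_sequence)
    else:
        loss = paddle.sum(masked_lm_loss * binary_sequence) / count
    return loss

# Toy usage: every token is masked out, so the mean would otherwise divide by zero.
print(masked_mean_loss(paddle.zeros([4])))  # Tensor with value 0.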
