
[Accuracy diff No.127] Fix accuracy diff for paddle.nn.functional.sigmoid_focal_loss API #73430

Closed · wants to merge 1 commit
89 changes: 18 additions & 71 deletions python/paddle/nn/functional/loss.py
Contributor

This looks like it directly replaces the C implementation with a Python implementation 😶‍🌫️😶‍🌫️. That does not match Paddle's contribution standards: it breaks the separation between the dynamic and static graph paths, and the change is too large.

paddleapitest is only a rough testing project; the Paddle implementation is always the reference~

Please take a look at the contribution guide 🫡: https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/index_cn.html

Contributor Author
@NKNaN · Jun 18, 2025

Then should the torch conversion rule in SigmoidFocalLossRule be changed to follow Paddle's implementation instead?
(If Paddle's implementation is confirmed to be correct, the description of label in the Paddle docs may also need updating; it should be required to be 0 or 1.)

Contributor

See the API documentation: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/functional/sigmoid_focal_loss_cn.html#sigmoid-focal-loss

There, label is a Tensor whose values may be any value in [0, 1]; there is no requirement that "label must be 0.0 or 1.0"~ Test code:

import paddle

# paddle.nn.functional.sigmoid_focal_loss(Tensor([270072, 80],"float32"), Tensor([270072, 80],"float32"), )

logit = paddle.randn([270072, 80], dtype="float32")
label = paddle.uniform([270072, 80], dtype="float32")

result = paddle.nn.functional.sigmoid_focal_loss(logit, label)
print(result)

That said, the accuracy diff could come from a mistake in SigmoidFocalLossRule, or from a problem in the Paddle kernel code. For the latter, check how the kernel handles it: paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu

Contributor Author

The formula in the documentation is
$\text{loss} = -\text{label} \cdot \alpha \cdot (1 - \sigma(\text{logit}))^{\gamma} \cdot \log(\sigma(\text{logit})) - (1 - \text{label}) \cdot (1 - \alpha) \cdot \sigma(\text{logit})^{\gamma} \cdot \log(1 - \sigma(\text{logit}))$

My understanding of Paddle's current implementation is:
loss = -label * log(sigmoid(logit)) - (1 - label) * log(1 - sigmoid(logit)) // _C_ops.sigmoid_cross_entropy_with_logits
// this step corresponds to the documented formula with the alpha and gamma factors removed

pred = sigmoid(logit)
p_t = pred * label + (1 - pred) * (1 - label)
alpha_t = alpha * label + (1 - alpha) * (1 - label)
loss = alpha_t * loss = [alpha * label + (1 - alpha) * (1 - label)] * [-label * log(sigmoid(logit)) - (1 - label) * log(1 - sigmoid(logit))]
// from this step on, it is clear that if label is not 0 or 1 the result is no longer what the documented formula computes; the same goes for gamma_t below.

My point is that although this API accepts label values between 0 and 1 and still produces a result, that result is not the one given by the formula in the docs; only when the label values are 0 or 1 does the output match the documented formula. (Before this change, the unit test also only covered cases where label is 0 or 1.)

If preferred, I can also replace the composition of Paddle APIs in this change with the corresponding _C_ops calls.
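
A minimal NumPy sketch (my own, not part of this PR) that makes the point concrete: the two formulations coincide for hard 0/1 labels and diverge for fractional labels.

import numpy as np

rng = np.random.default_rng(0)
logit = rng.normal(size=(4, 5))
alpha, gamma = 0.25, 2.0

def documented_loss(logit, label):
    # Formula as written in the docs:
    # -label * alpha * (1 - sigmoid)^gamma * log(sigmoid)
    # - (1 - label) * (1 - alpha) * sigmoid^gamma * log(1 - sigmoid)
    pred = 1.0 / (1.0 + np.exp(-logit))
    return (
        -label * alpha * (1 - pred) ** gamma * np.log(pred)
        - (1 - label) * (1 - alpha) * pred**gamma * np.log(1 - pred)
    )

def composite_loss(logit, label):
    # BCE-with-logits modulated by alpha_t and (1 - p_t)^gamma,
    # mirroring the current Paddle implementation sketched above.
    pred = 1.0 / (1.0 + np.exp(-logit))
    ce = -label * np.log(pred) - (1 - label) * np.log(1 - pred)
    p_t = pred * label + (1 - pred) * (1 - label)
    alpha_t = alpha * label + (1 - alpha) * (1 - label)
    return alpha_t * (1 - p_t) ** gamma * ce

binary_label = rng.integers(0, 2, size=(4, 5)).astype(np.float64)
soft_label = rng.uniform(0, 1, size=(4, 5))

print(np.allclose(documented_loss(logit, binary_label), composite_loss(logit, binary_label)))  # True
print(np.allclose(documented_loss(logit, soft_label), composite_loss(logit, soft_label)))      # False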

Contributor

I see what you mean now: the expression in the code is indeed inconsistent with the documentation. What the documentation means is
$\text{loss} = -\text{Labels} \cdot \alpha \cdot (1 - \sigma(\text{Logit}))^{\gamma} \log(\sigma(\text{Logit})) - (1 - \text{Labels}) \cdot (1 - \alpha) \cdot \sigma(\text{Logit})^{\gamma} \log(1 - \sigma(\text{Logit}))$

whereas what the code's modulation actually computes is
$\text{loss} = \alpha_t \cdot (1 - p_t)^{\gamma} \cdot \left[ -\text{Labels} \cdot \log(\text{pred}) - (1 - \text{Labels}) \cdot \log(1 - \text{pred}) \right]$
$\text{loss} = \left[ \text{Labels} \cdot \alpha + (1 - \text{Labels}) \cdot (1 - \alpha) \right] \cdot \left( 1 - \text{Labels} \cdot \text{pred} - (1 - \text{Labels}) \cdot (1 - \text{pred}) \right)^{\gamma} \cdot \left[ -\text{Labels} \cdot \log(\text{pred}) - (1 - \text{Labels}) \cdot \log(1 - \text{pred}) \right]$

This does deviate somewhat from the design, but for now it is considered a reasonable formulation; see Section 3.2 of the original paper: https://arxiv.org/pdf/1708.02002
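
A quick spot check with made-up numbers (mine, just to make the gap concrete): take $\sigma(\text{Logit}) = 0.7$, $\text{Labels} = 0.5$, $\alpha = 0.25$, $\gamma = 2$. The documented formula gives $-0.5 \cdot 0.25 \cdot 0.3^{2} \log(0.7) - 0.5 \cdot 0.75 \cdot 0.7^{2} \log(0.3) \approx 0.225$, while the code path gives $\alpha_t = 0.5$, $p_t = 0.5$, $\text{CE} \approx 0.780$, hence $0.5 \cdot 0.25 \cdot 0.780 \approx 0.098$. With $\text{Labels} = 1$ both reduce to $-\alpha (1 - \sigma)^{\gamma} \log(\sigma) \approx 0.008$.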

Contributor Author

Yes, Section 3 of the original paper also requires $y \in \{+1, -1\}$. Alternatively, how about changing get_numpy_tensor in PaddleAPITest/tester/api_config/config_analyzer.py so that the inputs it generates for this API are restricted to 0.0 or 1.0?
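
A hypothetical sketch of that restriction (the real get_numpy_tensor signature in config_analyzer.py may differ; names here are illustrative only):

import numpy as np

def make_binary_label(shape, dtype="float32", seed=None):
    # Draw hard 0/1 labels instead of uniform values in [0, 1].
    rng = np.random.default_rng(seed)
    return rng.integers(0, 2, size=shape).astype(dtype)

# Shape taken from the accuracy-diff config quoted earlier in this thread.
label_np = make_binary_label((270072, 80))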

Contributor

That works; the input does need to be restricted. Both the paper and the Paddle API docs describe the label as a 0/1 label, and the torchvision sigmoid_focal_loss implementation makes the same assumption and uses a similar computation.
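
A cross-check sketch (assuming both paddle and torchvision are installed; with hard 0/1 labels the two implementations should agree up to floating-point tolerance):

import numpy as np
import paddle
import torch
from torchvision.ops import sigmoid_focal_loss as tv_sigmoid_focal_loss

logit_np = np.random.randn(128, 80).astype("float32")
label_np = np.random.randint(0, 2, size=(128, 80)).astype("float32")

paddle_out = paddle.nn.functional.sigmoid_focal_loss(
    paddle.to_tensor(logit_np), paddle.to_tensor(label_np), reduction="none"
)
torch_out = tv_sigmoid_focal_loss(
    torch.from_numpy(logit_np),
    torch.from_numpy(label_np),
    alpha=0.25,
    gamma=2.0,
    reduction="none",
)
print(np.allclose(paddle_out.numpy(), torch_out.numpy(), rtol=1e-5, atol=1e-6))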

@@ -3358,81 +3358,28 @@ def sigmoid_focal_loss(
f"Expected zero or one dimension of normalizer in sigmoid_focal_loss but got {normalizer_dims}."
)

if in_dynamic_or_pir_mode():
place = _current_expected_place()
one = _C_ops.full(paddle.shape(logit), 1.0, logit.dtype, place)

loss = _C_ops.sigmoid_cross_entropy_with_logits(
logit, label, None, False, -100
)

pred = _C_ops.sigmoid(logit)

p_t = _C_ops.add(
_C_ops.multiply(pred, label),
_C_ops.multiply(
_C_ops.subtract(one, pred), _C_ops.subtract(one, label)
),
)

alpha = paddle.to_tensor(alpha, dtype=loss.dtype)
alpha_t = _C_ops.add(
_C_ops.multiply(alpha, label),
_C_ops.multiply(
_C_ops.subtract(one, alpha), _C_ops.subtract(one, label)
),
)
loss = _C_ops.multiply(alpha_t, loss)
pred = paddle.nn.functional.sigmoid(logit)

if in_dynamic_mode():
gamma = paddle.to_tensor(gamma, dtype=loss.dtype)
gamma_t = _C_ops.pow(_C_ops.subtract(one, p_t), gamma)
loss = _C_ops.multiply(gamma_t, loss)

if normalizer is not None:
loss = _C_ops.divide(loss, normalizer)

if reduction == "sum":
return _C_ops.sum(loss, [], None, False)
elif reduction == "mean":
return _C_ops.mean_all(loss)

return loss

else:
check_variable_and_dtype(
logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss'
)
check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss'
)

bce_name = None
if reduction == 'none' and normalizer is None:
bce_name = name
loss = paddle.nn.functional.binary_cross_entropy_with_logits(
logit, label, None, reduction='none', name=bce_name
)

pred = paddle.nn.functional.sigmoid(logit)
p_t = pred * label + (1 - pred) * (1 - label)

alpha_t = alpha * label + (1 - alpha) * (1 - label)
loss = paddle.multiply(alpha_t, loss)

gamma_t = paddle.pow((1 - p_t), gamma)
loss = paddle.multiply(gamma_t, loss)
positive_loss = (
-label * alpha * (paddle.pow(1.0 - pred, gamma)) * paddle.log(pred)
)
negative_loss = (
-(1.0 - label)
* (1.0 - alpha)
* paddle.pow(pred, gamma)
* paddle.log(1 - pred)
)
loss = positive_loss + negative_loss

if normalizer is not None:
normalizer_name = name if reduction == 'none' else None
loss = paddle.divide(loss, normalizer, name=normalizer_name)
if normalizer is not None:
loss = paddle.divide(loss, normalizer)

if reduction == 'mean':
loss = paddle.mean(loss, name=name)
elif reduction == 'sum':
loss = paddle.sum(loss, name=name)
if reduction == "sum":
return paddle.sum(loss)
elif reduction == "mean":
return paddle.mean(loss)

return loss
return loss


def multi_label_soft_margin_loss(
85 changes: 72 additions & 13 deletions test/legacy_test/test_sigmoid_focal_loss.py
@@ -92,21 +92,13 @@ def test_dygraph(
def calc_sigmoid_focal_loss(
logit_np, label_np, normalizer_np, alpha=0.25, gamma=2.0, reduction='sum'
):
loss = (
np.maximum(logit_np, 0)
- logit_np * label_np
+ np.log(1 + np.exp(-np.abs(logit_np)))
)

pred = 1 / (1 + np.exp(-logit_np))
p_t = pred * label_np + (1 - pred) * (1 - label_np)

if alpha is not None:
alpha_t = alpha * label_np + (1 - alpha) * (1 - label_np)
loss = alpha_t * loss

if gamma is not None:
loss = loss * ((1 - p_t) ** gamma)
positive_loss = -label_np * alpha * ((1.0 - pred) ** gamma) * np.log(pred)
negative_loss = (
-(1.0 - label_np) * (1.0 - alpha) * (pred**gamma) * np.log(1 - pred)
)
loss = positive_loss + negative_loss

if normalizer_np is not None:
loss = loss / normalizer_np
@@ -200,5 +192,72 @@ def test_SigmoidFocalLoss_error(self):
paddle.enable_static()


class TestSigmoidFocalLossFloatLabel(unittest.TestCase):

def test_SigmoidFocalLoss(self):
logit_np = np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype(
np.float64
)
label_np = np.random.uniform(0, 1, size=(2, 3, 4, 10)).astype(
np.float64
)
normalizer_nps = [
np.asarray([np.sum(label_np > 0)], dtype=label_np.dtype),
None,
]
places = []
if (
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
in ['1', 'true', 'on']
or not base.core.is_compiled_with_cuda()
):
places.append(base.CPUPlace())
if base.core.is_compiled_with_cuda():
places.append(base.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
alphas = [0.25, 0.5]
gammas = [3, 0.0]
for place in places:
for reduction in reductions:
for alpha in alphas:
for gamma in gammas:
for normalizer_np in normalizer_nps:
(static_result,) = test_static(
place,
logit_np,
label_np,
normalizer_np,
alpha,
gamma,
reduction,
)
dy_result = test_dygraph(
place,
logit_np,
label_np,
normalizer_np,
alpha,
gamma,
reduction,
)
expected = calc_sigmoid_focal_loss(
logit_np,
label_np,
normalizer_np,
alpha,
gamma,
reduction,
)
np.testing.assert_allclose(
static_result, expected, rtol=1e-05
)
np.testing.assert_allclose(
static_result, dy_result, rtol=1e-05
)
np.testing.assert_allclose(
dy_result, expected, rtol=1e-05
)


if __name__ == "__main__":
unittest.main()