From 7b981b43e891933b79acd19f36ce580da931f321 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sat, 21 Jun 2025 17:08:57 +0800 Subject: [PATCH 1/2] Fix --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 12 +++ paddle/phi/kernels/cpu/index_put_kernel.cc | 4 + .../phi/kernels/gpu/index_put_grad_kernel.cu | 13 +++ paddle/phi/kernels/gpu/index_put_kernel.cu | 4 + .../phi/kernels/xpu/index_put_grad_kernel.cc | 12 +++ paddle/phi/kernels/xpu/index_put_kernel.cc | 4 + test/legacy_test/test_index_put_op.py | 84 +++++++++++++++++++ 7 files changed, 133 insertions(+) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 7fd324a0bdf6b2..21592a6949c828 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -182,6 +182,18 @@ void IndexPutGradKernel(const Context& dev_ctx, bool accumulate, DenseTensor* x_grad, DenseTensor* value_grad) { + if (out_grad.numel() == 0) { + dev_ctx.template Alloc(x_grad); + // Fill value_grad with 0. + if (value_grad) { + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(value_grad->dims())), + 0, + value_grad); + } + return; + } PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index bd93bdc864cafa..973001ed52f5de 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -105,6 +105,10 @@ void IndexPutKernel(const Context& dev_ctx, const DenseTensor& value, bool accumulate, DenseTensor* out) { + if (out && out->numel() == 0) { + dev_ctx.template Alloc(out); + return; + } PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 07ecfd4ac28467..07620ac5cd5917 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -231,6 +231,19 @@ void IndexPutGradKernel(const Context& dev_ctx, bool accumulate, DenseTensor* x_grad, DenseTensor* value_grad) { + if (out_grad.numel() == 0) { + dev_ctx.template Alloc(x_grad); + // Fill value_grad with 0. + if (value_grad) { + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(value_grad->dims())), + 0, + value_grad); + } + return; + } + PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index e92c0244b4eded..034b74c5d9581d 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -116,6 +116,10 @@ void IndexPutKernel(const Context& dev_ctx, const DenseTensor& value, bool accumulate, DenseTensor* out) { + if (out && out->numel() == 0) { + dev_ctx.template Alloc(out); + return; + } PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc index 30f5138f1a85f1..fba3f42bff0990 100644 --- a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc @@ -31,6 +31,18 @@ void IndexPutGradKernel(const Context& dev_ctx, bool accumulate, DenseTensor* x_grad, DenseTensor* value_grad) { + if (out_grad.numel() == 0) { + dev_ctx.template Alloc(x_grad); + // Fill value_grad with 0. + if (value_grad) { + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(value_grad->dims())), + 0, + value_grad); + } + return; + } PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/paddle/phi/kernels/xpu/index_put_kernel.cc b/paddle/phi/kernels/xpu/index_put_kernel.cc index a265489ff39b4e..84e3dca80b19c2 100644 --- a/paddle/phi/kernels/xpu/index_put_kernel.cc +++ b/paddle/phi/kernels/xpu/index_put_kernel.cc @@ -28,6 +28,10 @@ void IndexPutKernel(const Context& dev_ctx, const DenseTensor& value, bool accumulate, DenseTensor* out) { + if (out && out->numel() == 0) { + dev_ctx.template Alloc(out); + return; + } PADDLE_ENFORCE_EQ( x.dtype(), value.dtype(), diff --git a/test/legacy_test/test_index_put_op.py b/test/legacy_test/test_index_put_op.py index 702a1fbefacf67..938432a27ddf48 100644 --- a/test/legacy_test/test_index_put_op.py +++ b/test/legacy_test/test_index_put_op.py @@ -1028,5 +1028,89 @@ def init_dtype_type(self): self.index_type_pd1 = "bool" +class TestIndexPutAPI_ZeroSize(unittest.TestCase): + def setUp(self): + self.mixed_indices = False + self.is_all_false = False + self.init_dtype_type() + self.setPlace() + + if self.mixed_indices: + tmp_indices_np1 = gen_indices_np( + self.x_shape, + self.indices_shapes, + self.index_type_np, + self.is_all_false, + ) + tmp_indices_np2 = gen_indices_np( + self.x_shape, + self.indices_shapes1, + self.index_type_np1, + self.is_all_false, + ) + self.indices_np = tuple( + list(tmp_indices_np1) + list(tmp_indices_np2) + ) + else: + self.indices_np = gen_indices_np( + self.x_shape, + self.indices_shapes, + self.index_type_np, + self.is_all_false, + ) + + def init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int64 + self.x_shape = (10, 0) + self.indices_shapes = [[10]] + self.value_shape = [1, 1] + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int64 + + def setPlace(self): + self.place = [] + if ( + os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() + in ['1', 'true', 'on'] + or not paddle.is_compiled_with_cuda() + ): + self.place.append('cpu') + if self.dtype_np is np.float16: + self.place = [] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + x_pd = paddle.randn(self.x_shape, dtype=self.dtype_pd) + x_np = x_pd.numpy() + value_pd = paddle.randn(self.value_shape, dtype=self.dtype_pd) + value_np = value_pd.numpy() + x_pd.stop_gradient = False + value_pd.stop_gradient = False + indices_pd = [ + paddle.randn(indices_shape).astype(dtype=self.index_type_pd) + for indices_shape in self.indices_shapes + ] + indices_np = [item.numpy() for item in indices_pd] + indices_pd = tuple(indices_pd) + accumulate = False + ref_res = compute_index_put_ref( + x_np, indices_np, value_np, accumulate + ) + pd_res = paddle.index_put(x_pd, indices_pd, value_pd, accumulate) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + # check grad + pd_res.sum().backward() + np.testing.assert_allclose(x_pd.grad.shape, x_pd.shape) + np.testing.assert_allclose( + value_pd.grad.numpy(), np.zeros(value_pd.shape) + ) + + if __name__ == '__main__': unittest.main() From 0d6e0f7a7c39d92e00e60efeab167f9e0d48fc38 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sat, 21 Jun 2025 17:16:15 +0800 Subject: [PATCH 2/2] Fix --- test/legacy_test/test_index_put_op.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/test/legacy_test/test_index_put_op.py b/test/legacy_test/test_index_put_op.py index 938432a27ddf48..089025383db5c4 100644 --- a/test/legacy_test/test_index_put_op.py +++ b/test/legacy_test/test_index_put_op.py @@ -1030,35 +1030,9 @@ def init_dtype_type(self): class TestIndexPutAPI_ZeroSize(unittest.TestCase): def setUp(self): - self.mixed_indices = False - self.is_all_false = False self.init_dtype_type() self.setPlace() - if self.mixed_indices: - tmp_indices_np1 = gen_indices_np( - self.x_shape, - self.indices_shapes, - self.index_type_np, - self.is_all_false, - ) - tmp_indices_np2 = gen_indices_np( - self.x_shape, - self.indices_shapes1, - self.index_type_np1, - self.is_all_false, - ) - self.indices_np = tuple( - list(tmp_indices_np1) + list(tmp_indices_np2) - ) - else: - self.indices_np = gen_indices_np( - self.x_shape, - self.indices_shapes, - self.index_type_np, - self.is_all_false, - ) - def init_dtype_type(self): self.dtype_np = np.float32 self.index_type_np = np.int64