diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc
index 7fd324a0bdf6b..21592a6949c82 100644
--- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc
@@ -182,6 +182,18 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  // Zero-size fast path: out_grad has no elements, so x_grad is an empty
+  // allocation and value_grad (if requested) receives all zeros.
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc
index bd93bdc864caf..973001ed52f5d 100644
--- a/paddle/phi/kernels/cpu/index_put_kernel.cc
+++ b/paddle/phi/kernels/cpu/index_put_kernel.cc
@@ -105,6 +105,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  // Zero-size fast path: nothing to scatter into an empty output.
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu
index 07ecfd4ac2846..07620ac5cd591 100644
--- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu
@@ -231,6 +231,19 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  // Zero-size fast path: out_grad has no elements, so x_grad is an empty
+  // allocation and value_grad (if requested) receives all zeros.
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
+
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu
index e92c0244b4ede..034b74c5d9581 100644
--- a/paddle/phi/kernels/gpu/index_put_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_put_kernel.cu
@@ -116,6 +116,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  // Zero-size fast path: nothing to scatter into an empty output.
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
index 30f5138f1a85f..fba3f42bff099 100644
--- a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
@@ -31,6 +31,18 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  // Zero-size fast path: out_grad has no elements, so x_grad is an empty
+  // allocation and value_grad (if requested) receives all zeros.
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/paddle/phi/kernels/xpu/index_put_kernel.cc b/paddle/phi/kernels/xpu/index_put_kernel.cc
index a265489ff39b4..84e3dca80b19c 100644
--- a/paddle/phi/kernels/xpu/index_put_kernel.cc
+++ b/paddle/phi/kernels/xpu/index_put_kernel.cc
@@ -28,6 +28,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  // Zero-size fast path: nothing to scatter into an empty output.
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
diff --git a/test/legacy_test/test_index_put_op.py b/test/legacy_test/test_index_put_op.py
index 702a1fbefacf6..089025383db5c 100644
--- a/test/legacy_test/test_index_put_op.py
+++ b/test/legacy_test/test_index_put_op.py
@@ -1028,5 +1028,63 @@ def init_dtype_type(self):
         self.index_type_pd1 = "bool"
 
 
+class TestIndexPutAPI_ZeroSize(unittest.TestCase):
+    def setUp(self):
+        self.init_dtype_type()
+        self.setPlace()
+
+    def init_dtype_type(self):
+        self.dtype_np = np.float32
+        self.index_type_np = np.int64
+        self.x_shape = (10, 0)
+        self.indices_shapes = [[10]]
+        self.value_shape = [1, 1]
+        self.dtype_pd = paddle.float32
+        self.index_type_pd = paddle.int64
+
+    def setPlace(self):
+        self.place = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not paddle.is_compiled_with_cuda()
+        ):
+            self.place.append('cpu')
+        if self.dtype_np is np.float16:
+            self.place = []
+        if paddle.is_compiled_with_cuda():
+            self.place.append('gpu')
+
+    def test_dygraph_forward(self):
+        paddle.disable_static()
+        for place in self.place:
+            paddle.device.set_device(place)
+            x_pd = paddle.randn(self.x_shape, dtype=self.dtype_pd)
+            x_np = x_pd.numpy()
+            value_pd = paddle.randn(self.value_shape, dtype=self.dtype_pd)
+            value_np = value_pd.numpy()
+            x_pd.stop_gradient = False
+            value_pd.stop_gradient = False
+            indices_pd = [
+                paddle.randn(indices_shape).astype(dtype=self.index_type_pd)
+                for indices_shape in self.indices_shapes
+            ]
+            indices_np = [item.numpy() for item in indices_pd]
+            indices_pd = tuple(indices_pd)
+            accumulate = False
+            ref_res = compute_index_put_ref(
+                x_np, indices_np, value_np, accumulate
+            )
+            pd_res = paddle.index_put(x_pd, indices_pd, value_pd, accumulate)
+            np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7)
+
+            # check grad
+            pd_res.sum().backward()
+            np.testing.assert_allclose(x_pd.grad.shape, x_pd.shape)
+            np.testing.assert_allclose(
+                value_pd.grad.numpy(), np.zeros(value_pd.shape)
+            )
+
+
 if __name__ == '__main__':
     unittest.main()