diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index 68a70aa273052..dbffe04a1d536 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -111,6 +111,10 @@ void NextafterKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { + if (x.numel() == 0 || y.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } dev_ctx.template Alloc(out); auto x_dims = x.dims(); auto y_dims = y.dims(); @@ -210,5 +214,6 @@ PD_REGISTER_KERNEL(copysign, double, phi::dtype::float16, phi::dtype::bfloat16) {} + PD_REGISTER_KERNEL( nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index 5986df3b60ece..94033f1fa1015 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -201,6 +201,10 @@ void NextafterKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { + if (x.numel() == 0 || y.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } std::vector inputs = {&x, &y}; std::vector outputs = {out}; dev_ctx.template Alloc(out); diff --git a/test/legacy_test/test_nextafter_op.py b/test/legacy_test/test_nextafter_op.py index 6ea7d2327988b..daacd63cef820 100644 --- a/test/legacy_test/test_nextafter_op.py +++ b/test/legacy_test/test_nextafter_op.py @@ -149,5 +149,31 @@ def init_shape(self): self.y_shape = (1,) +class TestNextafterOPZeroDim1(TestNextafterOP): + def setUp(self): + self.op_type = "nextafter" + self.python_api = paddle.nextafter + self.init_dtype() + + x = np.random.rand(0, 3, 2).astype(self.dtype) + y = np.random.rand(0, 3, 2).astype(self.dtype) + out = np.nextafter(x, y) + self.inputs = {'x': x, 'y': y} + self.outputs = {'out': out} + + +class TestNextafterOPZeroDim2(TestNextafterOP): + def setUp(self): + self.op_type = "nextafter" + self.python_api = paddle.nextafter + self.init_dtype() + + x = np.random.rand(4, 0, 2).astype(self.dtype) + y = np.random.rand(4, 0, 2).astype(self.dtype) + out = np.nextafter(x, y) + self.inputs = {'x': x, 'y': y} + self.outputs = {'out': out} + + if __name__ == "__main__": unittest.main()