Skip to content

Commit 62dcf8c

Browse files
authored
[0-size Tensor No.144] Add 0-size Tensor support for paddle.nextafter API. (#73008)
* fix * add test * fix * merge and fix * fix error * fix error * fix error
1 parent 598fd26 commit 62dcf8c

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

paddle/phi/kernels/cpu/elementwise_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ void NextafterKernel(const Context& dev_ctx,
111111
const DenseTensor& x,
112112
const DenseTensor& y,
113113
DenseTensor* out) {
114+
if (x.numel() == 0 || y.numel() == 0) {
115+
dev_ctx.template Alloc<T>(out);
116+
return;
117+
}
114118
dev_ctx.template Alloc<T>(out);
115119
auto x_dims = x.dims();
116120
auto y_dims = y.dims();
@@ -210,5 +214,6 @@ PD_REGISTER_KERNEL(copysign,
210214
double,
211215
phi::dtype::float16,
212216
phi::dtype::bfloat16) {}
217+
213218
PD_REGISTER_KERNEL(
214219
nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {}

paddle/phi/kernels/kps/elementwise_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ void NextafterKernel(const Context& dev_ctx,
201201
const DenseTensor& x,
202202
const DenseTensor& y,
203203
DenseTensor* out) {
204+
if (x.numel() == 0 || y.numel() == 0) {
205+
dev_ctx.template Alloc<T>(out);
206+
return;
207+
}
204208
std::vector<const DenseTensor*> inputs = {&x, &y};
205209
std::vector<DenseTensor*> outputs = {out};
206210
dev_ctx.template Alloc<T>(out);

test/legacy_test/test_nextafter_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,31 @@ def init_shape(self):
149149
self.y_shape = (1,)
150150

151151

152+
class TestNextafterOPZeroDim1(TestNextafterOP):
153+
def setUp(self):
154+
self.op_type = "nextafter"
155+
self.python_api = paddle.nextafter
156+
self.init_dtype()
157+
158+
x = np.random.rand(0, 3, 2).astype(self.dtype)
159+
y = np.random.rand(0, 3, 2).astype(self.dtype)
160+
out = np.nextafter(x, y)
161+
self.inputs = {'x': x, 'y': y}
162+
self.outputs = {'out': out}
163+
164+
165+
class TestNextafterOPZeroDim2(TestNextafterOP):
166+
def setUp(self):
167+
self.op_type = "nextafter"
168+
self.python_api = paddle.nextafter
169+
self.init_dtype()
170+
171+
x = np.random.rand(4, 0, 2).astype(self.dtype)
172+
y = np.random.rand(4, 0, 2).astype(self.dtype)
173+
out = np.nextafter(x, y)
174+
self.inputs = {'x': x, 'y': y}
175+
self.outputs = {'out': out}
176+
177+
152178
if __name__ == "__main__":
153179
unittest.main()

0 commit comments

Comments
 (0)