Skip to content

Commit 8f177f8

Browse files
committed
Fix 0-size tensor support in mean when out->numel() > 0, add test.
1 parent cc707ac commit 8f177f8

File tree

5 files changed

+38
-24
lines changed

5 files changed

+38
-24
lines changed

paddle/phi/kernels/cpu/reduce_mean_kernel.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,7 @@ void MeanRawKernel(const Context& dev_ctx,
2929
bool keep_dim,
3030
bool reduce_all,
3131
DenseTensor* out) {
32-
if (out && out->numel() == 0) {
33-
dev_ctx.template Alloc<T>(out);
34-
return;
35-
}
36-
37-
if (x.numel() == 0 && out && out->dims().size() == 0) {
32+
if (x.numel() == 0) {
3833
phi::Full<T, Context>(
3934
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
4035
return;

paddle/phi/kernels/kps/reduce_kernel.cu

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,12 @@ void MeanRawKernel(const Context& dev_ctx,
119119
bool keep_dim,
120120
bool reduce_all,
121121
DenseTensor* out) {
122-
if (out && out->numel() == 0) {
123-
dev_ctx.template Alloc<T>(out);
124-
return;
125-
}
126-
127-
if (x.numel() == 0 && out && out->dims().size() == 0) {
122+
if (x.numel() == 0) {
128123
phi::Full<T, Context>(
129124
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
130125
return;
131126
}
127+
132128
reduce_all = recompute_reduce_all(x, dims, reduce_all);
133129
auto out_dtype = x.dtype();
134130
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor, true>(

paddle/phi/kernels/onednn/reduce_mean_kernel.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,12 @@ void MeanRawKernel(const Context& dev_ctx,
2525
bool keep_dim,
2626
bool reduce_all,
2727
DenseTensor* out) {
28-
if (out && out->numel() == 0) {
29-
dev_ctx.template Alloc<T>(out);
30-
return;
31-
}
32-
33-
if (x.numel() == 0 && out && out->dims().size() == 0) {
28+
if (x.numel() == 0) {
3429
phi::Full<T, Context>(
3530
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
3631
return;
3732
}
33+
3834
reduce_all = recompute_reduce_all(x, dims, reduce_all);
3935
ReduceKernel<T, Context>(dev_ctx,
4036
x,

paddle/phi/kernels/xpu/reduce_mean_kernel.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,12 @@ void MeanRawKernel(const Context& dev_ctx,
2929
bool keep_dim,
3030
bool reduce_all,
3131
DenseTensor* out) {
32-
if (out && out->numel() == 0) {
33-
dev_ctx.template Alloc<T>(out);
34-
return;
35-
}
36-
37-
if (x.numel() == 0 && out && out->dims().size() == 0) {
32+
if (x.numel() == 0) {
3833
phi::Full<T, Context>(
3934
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
4035
return;
4136
}
37+
4238
reduce_all = recompute_reduce_all(x, dims, reduce_all);
4339
using XPUType = typename XPUTypeTrait<T>::Type;
4440
auto f = [](xpu::Context* ctx,

test/legacy_test/test_mean_op.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,37 @@ def test_check_grad(self):
10001000
)
10011001

10021002

1003+
class TestMeanOp_ZeroSize3(OpTest):
1004+
def setUp(self):
1005+
self.op_type = 'mean'
1006+
self.python_api = paddle.mean
1007+
self.init_prim_type()
1008+
self.dtype = 'float64'
1009+
self.shape = [2, 0, 4]
1010+
self.axis = 1
1011+
self.keepdim = False
1012+
self.set_attrs()
1013+
1014+
self.inputs = {'X': np.array([], dtype=self.dtype).reshape(self.shape)}
1015+
self.outputs = {
1016+
'Out': np.mean(
1017+
self.inputs["X"], axis=self.axis, keepdims=self.keepdim
1018+
)
1019+
}
1020+
1021+
def set_attrs(self):
1022+
pass
1023+
1024+
def init_prim_type(self):
1025+
self.prim_op_type = "comp"
1026+
1027+
def test_check_output(self):
1028+
self.check_output(check_pir=True, equal_nan=True)
1029+
1030+
def test_check_grad(self):
1031+
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)
1032+
1033+
10031034
if __name__ == "__main__":
10041035
paddle.enable_static()
10051036
unittest.main()

0 commit comments

Comments
 (0)