Skip to content

Commit 7a79fd8

Browse files
authored
Fix unsqueeze with empty axis bug (#51828)
1 parent 6ac7cab commit 7a79fd8

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

paddle/phi/infermeta/unary.cc

+3-8
Original file line numberDiff line numberDiff line change
@@ -4791,15 +4791,10 @@ void UnsqueezeInferMeta(const MetaTensor& x,
47914791
std::vector<int64_t> vec_out_dims(output_size, -1);
47924792
out->set_dtype(x.dtype());
47934793
out->set_dims(phi::make_ddim(vec_out_dims));
4794-
} else if (!axes.GetData().empty()) {
4795-
std::vector<int32_t> tmp;
4796-
tmp.reserve(axes.GetData().size());
4797-
std::for_each(axes.GetData().begin(),
4798-
axes.GetData().end(),
4799-
[&tmp](const int64_t& t) { tmp.push_back(t); });
4800-
auto out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
4794+
} else {
4795+
auto out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims);
48014796
out->set_dims(out_dims);
4802-
if (x_dims[0] == out_dims[0]) {
4797+
if (x_dims.size() > 0 && x_dims[0] == out_dims[0]) {
48034798
out->share_lod(x);
48044799
}
48054800
out->set_dtype(x.dtype());

paddle/phi/kernels/funcs/unsqueeze.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,
103103
return phi::make_ddim(output_shape);
104104
}
105105

106-
inline DDim GetUnsqueezeShape(const std::vector<int> unsqz_dims,
106+
inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
107107
const DDim& in_dims) {
108108
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
109109
int cur_output_size = in_dims.size();

paddle/phi/kernels/unsqueeze_kernel.cc

+1-6
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,7 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
2828
auto x_dims = x.dims();
2929
auto out_dims = out->dims();
3030
if (axes.FromTensor()) {
31-
std::vector<int32_t> tmp;
32-
tmp.reserve(axes.GetData().size());
33-
std::for_each(axes.GetData().begin(),
34-
axes.GetData().end(),
35-
[&tmp](const int64_t& t) { tmp.push_back(t); });
36-
out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
31+
out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims);
3732
}
3833
out->Resize(out_dims);
3934
dev_ctx.template Alloc<T>(out);

python/paddle/fluid/tests/unittests/test_unsqueeze_op.py

+16
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,22 @@ def init_test_case(self):
108108
self.new_shape = (10, 1, 1, 2, 5, 1)
109109

110110

111+
# axis is empty, x is ND
112+
class TestUnsqueezeOp5(TestUnsqueezeOp):
113+
def init_test_case(self):
114+
self.ori_shape = ()
115+
self.axes = ()
116+
self.new_shape = ()
117+
118+
119+
# axis is empty, x is 0D
120+
class TestUnsqueezeOp6(TestUnsqueezeOp):
121+
def init_test_case(self):
122+
self.ori_shape = (10, 2, 5)
123+
self.axes = ()
124+
self.new_shape = (10, 2, 5)
125+
126+
111127
class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
112128
def init_test_case(self):
113129
self.ori_shape = ()

0 commit comments

Comments
 (0)