Skip to content

Commit 34fafb1

Browse files
authored
[cherry-pick]Fix paddle.queeze_ bug (#49937)
* Fix paddle.queeze_ bug (#49903) * fix queeze_ bug * fix slove use squeeze_kernel * fix slove use squeeze_kernel * fix slove use squeeze_kernel * add test case * Update squeeze_kernel.h
1 parent 0699afb commit 34fafb1

File tree

4 files changed

+47
-14
lines changed

4 files changed

+47
-14
lines changed

paddle/phi/kernels/impl/solve_kernel_impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
169169
out_tmp.Resize(out->dims());
170170
out_tmp = *out;
171171

172-
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
172+
phi::Squeeze<T, Context>(dev_ctx, out_tmp, {-1}, out);
173173
} else {
174174
PADDLE_ENFORCE_EQ(
175175
x_dim[x_dim_size - 1],

paddle/phi/kernels/impl/squeeze_kernel_impl.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@ void SqueezeKernel(const Context& dev_ctx,
2323
const DenseTensor& x,
2424
const IntArray& axes,
2525
DenseTensor* out) {
26-
auto x_dims = x.dims();
27-
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
28-
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
29-
out->Resize(out_dims);
30-
26+
auto out_dims = out->dims();
3127
dev_ctx.template Alloc<T>(out);
3228
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3329
out->Resize(out_dims); // copy will reset the dims.

paddle/phi/kernels/squeeze_kernel.h

+11
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "paddle/phi/common/int_array.h"
1919
#include "paddle/phi/core/dense_tensor.h"
20+
#include "paddle/phi/infermeta/unary.h"
2021

2122
namespace phi {
2223

@@ -33,4 +34,14 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx,
3334
DenseTensor* out,
3435
DenseTensor* xshape);
3536

37+
template <typename T, typename Context>
38+
void Squeeze(const Context& dev_ctx,
39+
const DenseTensor& x,
40+
const IntArray& axes,
41+
DenseTensor* out) {
42+
MetaTensor meta_out(out);
43+
SqueezeInferMeta(x, axes, &meta_out);
44+
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
45+
}
46+
3647
} // namespace phi

python/paddle/fluid/tests/unittests/test_squeeze2_op.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
# Correct: General.
3030
class TestSqueezeOp(OpTest):
31-
3231
def setUp(self):
3332
self.op_type = "squeeze2"
3433
self.python_api = paddle.squeeze
@@ -40,7 +39,7 @@ def setUp(self):
4039
self.init_attrs()
4140
self.outputs = {
4241
"Out": self.inputs["X"].reshape(self.new_shape),
43-
"XShape": np.random.random(self.ori_shape).astype("float64")
42+
"XShape": np.random.random(self.ori_shape).astype("float64"),
4443
}
4544

4645
def test_check_output(self):
@@ -60,7 +59,6 @@ def init_attrs(self):
6059

6160
# Correct: There is mins axis.
6261
class TestSqueezeOp1(TestSqueezeOp):
63-
6462
def init_test_case(self):
6563
self.ori_shape = (1, 20, 1, 5)
6664
self.axes = (0, -2)
@@ -69,7 +67,6 @@ def init_test_case(self):
6967

7068
# Correct: No axes input.
7169
class TestSqueezeOp2(TestSqueezeOp):
72-
7370
def init_test_case(self):
7471
self.ori_shape = (1, 20, 1, 5)
7572
self.axes = ()
@@ -78,15 +75,13 @@ def init_test_case(self):
7875

7976
# Correct: Just part of axes be squeezed.
8077
class TestSqueezeOp3(TestSqueezeOp):
81-
8278
def init_test_case(self):
8379
self.ori_shape = (6, 1, 5, 1, 4, 1)
8480
self.axes = (1, -1)
8581
self.new_shape = (6, 5, 1, 4)
8682

8783

8884
class TestSqueeze2AxesTensor(UnittestBase):
89-
9085
def init_info(self):
9186
self.shapes = [[2, 3, 4]]
9287
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')
@@ -123,7 +118,6 @@ def test_static(self):
123118

124119

125120
class TestSqueeze2AxesTensorList(UnittestBase):
126-
127121
def init_info(self):
128122
self.shapes = [[2, 3, 4]]
129123
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')
@@ -140,7 +134,7 @@ def test_static(self):
140134
# axes is a list[Variable]
141135
axes = [
142136
paddle.full([1], 0, dtype='int32'),
143-
paddle.full([1], 2, dtype='int32')
137+
paddle.full([1], 2, dtype='int32'),
144138
]
145139
out = paddle.squeeze(feat, axes)
146140
out2 = paddle.fluid.layers.squeeze(feat, axes)
@@ -162,5 +156,37 @@ def test_static(self):
162156
self.assertEqual(infer_out.shape, (2, 3, 10))
163157

164158

159+
# test api
160+
class TestSqueezeAPI(unittest.TestCase):
161+
def setUp(self):
162+
self.executed_api()
163+
164+
def executed_api(self):
165+
self.squeeze = paddle.squeeze
166+
167+
def test_api(self):
168+
paddle.disable_static()
169+
input_data = np.random.random([3, 2, 1]).astype("float32")
170+
x = paddle.to_tensor(input_data)
171+
out = self.squeeze(x, axis=2)
172+
out.backward()
173+
174+
self.assertEqual(out.shape, [3, 2])
175+
176+
paddle.enable_static()
177+
178+
def test_error(self):
179+
def test_axes_type():
180+
x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32")
181+
self.squeeze(x2, axis=2.1)
182+
183+
self.assertRaises(TypeError, test_axes_type)
184+
185+
186+
class TestSqueezeInplaceAPI(TestSqueezeAPI):
187+
def executed_api(self):
188+
self.squeeze = paddle.squeeze_
189+
190+
165191
if __name__ == "__main__":
166192
unittest.main()

0 commit comments

Comments
 (0)