Skip to content

Commit 852c78b

Browse files
authored
[cherry-pick] fix flatten infershape (#35398)
* fix flatten infershape; test=develop * fix flatten infershape; test=develop
1 parent e04b66f commit 852c78b

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

paddle/fluid/operators/flatten_op.cc

+15-3
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel {
5555
int64_t outer = 1, inner = 1;
5656
for (int i = 0; i < in_dims.size(); ++i) {
5757
if (i < axis) {
58-
outer *= in_dims[i];
58+
if (in_dims[i] == -1 || outer == -1) {
59+
outer = -1;
60+
} else {
61+
outer *= in_dims[i];
62+
}
5963
} else {
60-
inner *= in_dims[i];
64+
if (in_dims[i] == -1 || inner == -1) {
65+
inner = -1;
66+
} else {
67+
inner *= in_dims[i];
68+
}
6169
}
6270
}
6371
std::vector<int32_t> out_shape(2);
@@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
296304
out_shape.push_back(in_dims[i]);
297305
}
298306
for (int i = start_axis; i <= stop_axis; i++) {
299-
outer *= in_dims[i];
307+
if (in_dims[i] == -1 || outer == -1) {
308+
outer = -1;
309+
} else {
310+
outer *= in_dims[i];
311+
}
300312
}
301313
out_shape.push_back(outer);
302314
for (int i = stop_axis + 1; i < in_dims_size; i++) {

python/paddle/fluid/tests/unittests/test_flatten2_op.py

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import unittest
1818
import numpy as np
1919
import paddle.fluid as fluid
20+
import paddle
2021
from op_test import OpTest
2122

2223

@@ -69,6 +70,20 @@ def init_test_case(self):
6970
self.new_shape = (36, 16)
7071

7172

73+
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
74+
def execute_api(self, x, axis=1):
75+
return fluid.layers.flatten(x, axis=axis)
76+
77+
def test_static_api(self):
78+
paddle.enable_static()
79+
main_prog = paddle.static.Program()
80+
with paddle.static.program_guard(main_prog, paddle.static.Program()):
81+
x = paddle.static.data(
82+
name="x", shape=[-1, 3, -1, -1], dtype='float32')
83+
out = self.execute_api(x, axis=2)
84+
self.assertTrue((-1, -1) == out.shape)
85+
86+
7287
class TestFlatten2OpError(unittest.TestCase):
7388
def test_errors(self):
7489
with fluid.program_guard(fluid.Program(), fluid.Program()):

python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py

+14
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def test_static_api(self):
201201
self.assertTrue((2, 3, 16) == fetch_out[0].shape)
202202

203203

204+
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
205+
def execute_api(self, x, start_axis=0, stop_axis=-1):
206+
return paddle.flatten(x, start_axis, stop_axis)
207+
208+
def test_static_api(self):
209+
paddle.enable_static()
210+
main_prog = paddle.static.Program()
211+
with paddle.static.program_guard(main_prog, paddle.static.Program()):
212+
x = paddle.static.data(
213+
name="x", shape=[-1, 3, -1, -1], dtype='float32')
214+
out = self.execute_api(x, start_axis=2, stop_axis=3)
215+
self.assertTrue((-1, 3, -1) == out.shape)
216+
217+
204218
class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
205219
def execute_api(self, x, start_axis=0, stop_axis=-1):
206220
return x.flatten_(start_axis, stop_axis)

0 commit comments

Comments
 (0)