Skip to content

Commit 0d32f55

Browse files
authored
fix the indexerror of conv2d_transpose (#50005)
1 parent 1755a15 commit 0d32f55

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,17 @@ def error_groups():
989989

990990
self.assertRaises(ValueError, error_groups)
991991

992+
def error_0_filter_number():
993+
out = paddle.static.nn.conv2d_transpose(
994+
input=data,
995+
groups=1,
996+
num_filters=0,
997+
filter_size=3,
998+
data_format='NCHW',
999+
)
1000+
1001+
self.assertRaises(ValueError, error_0_filter_number)
1002+
9921003

9931004
class TestConv2DTransposeRepr(unittest.TestCase):
9941005
def test_case(self):

python/paddle/static/nn/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,9 @@ def conv2d_transpose(
15421542
"but received {}".format(len(input.shape))
15431543
)
15441544

1545+
if num_filters == 0:
1546+
raise ValueError("num of filters should not be 0.")
1547+
15451548
if data_format not in ['NCHW', 'NHWC']:
15461549
raise ValueError(
15471550
"Attr(data_format) of Op(paddle.static.nn.layers.conv2d_transpose) got wrong value: received "

0 commit comments

Comments
 (0)