Skip to content

Commit da11aa4

Browse files
authored
Fix Python IndexError of case13: paddle.static.nn.batch_norm (#50011)
* add channel_num check for paddle.static.nn.batch_norm * fix bugs * fix bugs
1 parent 0d32f55 commit da11aa4

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

+4
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,10 @@ def test_errors(self):
768768
)
769769
self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2)
770770

771+
# the first dimension of input for batch_norm must between [2d, 5d].
772+
x3 = paddle.static.data("", shape=[0], dtype="float32")
773+
self.assertRaises(ValueError, paddle.static.nn.batch_norm, x3)
774+
771775

772776
class TestDygraphBatchNormAPIError(unittest.TestCase):
773777
def test_errors(self):

python/paddle/fluid/tests/unittests/test_fold_op.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_errors(self):
179179
with program_guard(Program(), Program()):
180180

181181
def test_input_shape():
182-
# input_shpae must be 3-D
182+
# input_shape must be 3-D
183183
x = paddle.randn(shape=[2, 3, 6, 7], dtype="float32")
184184
out = fold(x, output_sizes=[2, 3], kernel_sizes=[2, 2])
185185

python/paddle/static/nn/common.py

+6
Original file line numberDiff line numberDiff line change
@@ -2731,6 +2731,12 @@ def batch_norm(
27312731
dtype = core.VarDesc.VarType.FP32
27322732

27332733
input_shape = input.shape
2734+
if len(input.shape) < 2 or len(input.shape) > 5:
2735+
raise ValueError(
2736+
'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format(
2737+
len(input.shape), input_shape
2738+
)
2739+
)
27342740
if data_layout == 'NCHW':
27352741
channel_num = input_shape[1]
27362742
else:

0 commit comments

Comments
 (0)