Skip to content

Commit 4098dbc

Browse files
committed
add unittest
1 parent 48ab823 commit 4098dbc

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

python/paddle/fluid/tests/unittests/test_pool1d_api.py

+10
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,16 @@ def run_zero_stride():
429429

430430
self.assertRaises(ValueError, run_zero_stride)
431431

432+
def run_zero_tuple_stride():
433+
with fluid.dygraph.guard():
434+
array = np.array([1], dtype=np.float32)
435+
x = paddle.to_tensor(
436+
np.reshape(array, [1, 1, 1]), dtype='float32'
437+
)
438+
out = F.max_pool1d(x, 1, stride=(0))
439+
440+
self.assertRaises(ValueError, run_zero_tuple_stride)
441+
432442

433443
if __name__ == '__main__':
434444
unittest.main()

python/paddle/fluid/tests/unittests/test_pool2d_api.py

+12
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,18 @@ def run_zero_stride():
609609

610610
self.assertRaises(ValueError, run_zero_stride)
611611

612+
def run_zero_tuple_stride():
613+
with fluid.dygraph.guard():
614+
array = np.array([1], dtype=np.float32)
615+
x = paddle.to_tensor(
616+
np.reshape(array, [1, 1, 1, 1]), dtype='float32'
617+
)
618+
out = max_pool2d(
619+
x, 1, stride=(0, 0), return_mask=False, data_format='NHWC'
620+
)
621+
622+
self.assertRaises(ValueError, run_zero_tuple_stride)
623+
612624

613625
if __name__ == '__main__':
614626
unittest.main()

python/paddle/fluid/tests/unittests/test_pool3d_api.py

+10
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,16 @@ def run_zero_stride():
575575

576576
self.assertRaises(ValueError, run_zero_stride)
577577

578+
def run_zero_tuple_stride():
579+
with fluid.dygraph.guard():
580+
array = np.array([1], dtype=np.float32)
581+
x = paddle.to_tensor(
582+
np.reshape(array, [1, 1, 1, 1, 1]), dtype='float32'
583+
)
584+
out = max_pool3d(x, 1, stride=(0, 0, 0), ceil_mode=False)
585+
586+
self.assertRaises(ValueError, run_zero_tuple_stride)
587+
578588

579589
if __name__ == '__main__':
580590
unittest.main()

0 commit comments

Comments
 (0)