Skip to content

Commit 343d1c1

Browse files
Liyulingyuepangengzheng
authored andcommitted
fix the div 0 errors in psroi_pool (PaddlePaddle#49965)
* fix the div 0 errors in psroi_pool * fix case 7 * rool back sth.
1 parent 1211722 commit 343d1c1

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

python/paddle/fluid/tests/unittests/test_psroi_pool_op.py

+16
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,22 @@ def test_channel_error():
339339
self.assertRaises(ValueError, test_channel_error)
340340

341341

342+
class TestPSROIPoolZeroDivError(unittest.TestCase):
343+
def setUp(self):
344+
paddle.disable_static()
345+
self.x = paddle.uniform([2, 490, 28, 28], dtype='float32')
346+
self.boxes = paddle.to_tensor(
347+
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32'
348+
)
349+
self.boxes_num = paddle.to_tensor([1, 2], dtype='int32')
350+
351+
def test_errors(self):
352+
def test_zero_div_error():
353+
paddle.vision.ops.psroi_pool(self.x, self.boxes, self.boxes_num, 0)
354+
355+
self.assertRaises(ValueError, test_zero_div_error)
356+
357+
342358
class TestPSROIPoolStaticAPI(unittest.TestCase):
343359
def setUp(self):
344360
paddle.enable_static()

python/paddle/vision/ops.py

+2
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,8 @@ def psroi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None):
14241424
output_size = (output_size, output_size)
14251425
pooled_height, pooled_width = output_size
14261426
assert len(x.shape) == 4, "Input features with shape should be (N, C, H, W)"
1427+
if pooled_height * pooled_width == 0:
1428+
raise ValueError('output_size should not contain 0.')
14271429
output_channels = int(x.shape[1] / (pooled_height * pooled_width))
14281430
if in_dygraph_mode():
14291431
return _C_ops.psroi_pool(

0 commit comments

Comments
 (0)