Skip to content

Commit 42d12ea

Browse files
authored
[cherry-pick] support batch_size=2 in RCNN (#4788)
1 parent 472c288 commit 42d12ea

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

ppdet/modeling/post_process.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,23 @@ def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
9191
including labels, scores and bboxes.
9292
"""
9393

94-
if bboxes.shape[0] == 0:
95-
bboxes = self.fake_bboxes
96-
bbox_num = self.fake_bbox_num
94+
bboxes_list = []
95+
bbox_num_list = []
96+
id_start = 0
97+
# add fake bbox when output is empty for each batch
98+
for i in range(bbox_num.shape[0]):
99+
if bbox_num[i] == 0:
100+
bboxes_i = self.fake_bboxes
101+
bbox_num_i = self.fake_bbox_num
102+
id_start += 1
103+
else:
104+
bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
105+
bbox_num_i = bbox_num[i]
106+
id_start += bbox_num[i]
107+
bboxes_list.append(bboxes_i)
108+
bbox_num_list.append(bbox_num_i)
109+
bboxes = paddle.concat(bboxes_list)
110+
bbox_num = paddle.concat(bbox_num_list)
97111

98112
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
99113

@@ -156,6 +170,7 @@ def paste_mask(self, masks, boxes, im_h, im_w):
156170
"""
157171
Paste the mask prediction to the original image.
158172
"""
173+
159174
x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
160175
masks = paddle.unsqueeze(masks, [0, 1])
161176
img_y = paddle.arange(0, im_h, dtype='float32') + 0.5

0 commit comments

Comments
 (0)