@@ -91,9 +91,23 @@ def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
91
91
including labels, scores and bboxes.
92
92
"""
93
93
94
- if bboxes .shape [0 ] == 0 :
95
- bboxes = self .fake_bboxes
96
- bbox_num = self .fake_bbox_num
94
+ bboxes_list = []
95
+ bbox_num_list = []
96
+ id_start = 0
97
+ # add fake bbox when output is empty for each batch
98
+ for i in range (bbox_num .shape [0 ]):
99
+ if bbox_num [i ] == 0 :
100
+ bboxes_i = self .fake_bboxes
101
+ bbox_num_i = self .fake_bbox_num
102
+ id_start += 1
103
+ else :
104
+ bboxes_i = bboxes [id_start :id_start + bbox_num [i ], :]
105
+ bbox_num_i = bbox_num [i ]
106
+ id_start += bbox_num [i ]
107
+ bboxes_list .append (bboxes_i )
108
+ bbox_num_list .append (bbox_num_i )
109
+ bboxes = paddle .concat (bboxes_list )
110
+ bbox_num = paddle .concat (bbox_num_list )
97
111
98
112
origin_shape = paddle .floor (im_shape / scale_factor + 0.5 )
99
113
@@ -156,6 +170,7 @@ def paste_mask(self, masks, boxes, im_h, im_w):
156
170
"""
157
171
Paste the mask prediction to the original image.
158
172
"""
173
+
159
174
x0 , y0 , x1 , y1 = paddle .split (boxes , 4 , axis = 1 )
160
175
masks = paddle .unsqueeze (masks , [0 , 1 ])
161
176
img_y = paddle .arange (0 , im_h , dtype = 'float32' ) + 0.5
0 commit comments