@@ -58,7 +58,6 @@ def __init__(self, args):
                 logger.warning(
                     "When args.layout is false, args.ocr is automatically set to false"
                 )
-            args.drop_score = 0
             # init model
             self.layout_predictor = None
             self.text_system = None
@@ -93,6 +92,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             'all': 0
         }
         start = time.time()
+
         if self.image_orientation_predictor is not None:
             tic = time.time()
             cls_result = self.image_orientation_predictor.predict(
@@ -108,6 +108,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 img = cv2.rotate(img, cv_rotate_code[angle])
             toc = time.time()
             time_dict['image_orientation'] = toc - tic
+
         if self.mode == 'structure':
             ori_im = img.copy()
             if self.layout_predictor is not None:
@@ -116,6 +117,20 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             else:
                 h, w = ori_im.shape[:2]
                 layout_res = [dict(bbox=None, label='table')]
+
+            # As reported in issues such as #10270 and #11665, the old
+            # implementation, which recognized text directly from the layout
+            # regions, suffered from poor OCR recognition accuracy.
+            #
+            # To improve recognition accuracy, this patch fix first detects
+            # all text regions with the text_detector and then recognizes the
+            # text in those regions (intersected with the layout regions)
+            # using the text_recognizer.
+            dt_boxes = []
+            if self.text_system is not None:
+                dt_boxes, elapse = self.text_system.text_detector(img)
+                time_dict['det'] = elapse
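+                # Detection time for the whole page; per-region recognition
+                # timings are accumulated on top of this further below.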
+
             res_list = []
             for region in layout_res:
                 res = ''
@@ -126,6 +141,8 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 else:
                     x1, y1, x2, y2 = 0, 0, w, h
                     roi_img = ori_im
+                bbox = [x1, y1, x2, y2]
+
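+                # bbox is kept in full-image coordinates: it selects the
+                # detected text boxes for this region and is reported below.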
                 if region['label'] == 'table':
                     if self.table_system is not None:
                         res, table_time_dict = self.table_system(
@@ -136,66 +153,99 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                         time_dict['rec'] += table_time_dict['rec']
                 else:
                     if self.text_system is not None:
-                        if self.recovery:
-                            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
-                            wht_im[y1:y2, x1:x2, :] = roi_img
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                wht_im)
-                        else:
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                roi_img)
+                        res, ocr_time_dict = self._predict_text(ori_im, roi_img, bbox, dt_boxes)
                         time_dict['det'] += ocr_time_dict['det']
                         time_dict['rec'] += ocr_time_dict['rec']

-                        # remove style char,
-                        # when using the recognition model trained on the PubtabNet dataset,
-                        # it will recognize the text format in the table, such as <b>
-                        style_token = [
-                            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
-                            '</b>', '<sub>', '</sup>', '<overline>',
-                            '</overline>', '<underline>', '</underline>', '<i>',
-                            '</i>'
-                        ]
-                        res = []
-                        for box, rec_res in zip(filter_boxes, filter_rec_res):
-                            rec_str, rec_conf = rec_res[0], rec_res[1]
-                            for token in style_token:
-                                if token in rec_str:
-                                    rec_str = rec_str.replace(token, '')
-                            if not self.recovery:
-                                box += [x1, y1]
-                            if self.return_word_box:
-                                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist(),
-                                    'text_word': word_box_content_list,
-                                    'text_word_region': word_box_list
-                                })
-                            else:
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist()
-                                })
             res_list.append({
                 'type': region['label'].lower(),
-                'bbox': [x1, y1, x2, y2],
+                'bbox': bbox,
                 'img': roi_img,
                 'res': res,
                 'img_idx': img_idx
             })
+
             end = time.time()
             time_dict['all'] = end - start
             return res_list, time_dict
+
         elif self.mode == 'kie':
             re_res, elapse = self.kie_predictor(img)
             time_dict['kie'] = elapse
             time_dict['all'] = elapse
             return re_res[0], time_dict
+
         return None, None

+    def _predict_text(self, ori_im, roi_img, bbox, dt_boxes):
+        x1, y1, x2, y2 = bbox
+
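+        # In recovery mode, paste the region onto a page-sized blank canvas so
+        # the boxes returned by detection stay in full-page coordinates.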
+        if self.recovery:
+            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
+            wht_im[y1:y2, x1:x2, :] = roi_img
+            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+                wht_im)
+        else:
+            # Filter the text regions that intersect with the current bbox.
+            intersecting_dt_boxes = self._filter_boxes(dt_boxes, bbox)
+            # Recognize texts from these intersecting text regions.
+            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+                ori_im, dt_boxes=intersecting_dt_boxes)
+
+        # Remove style tokens:
+        # a recognition model trained on the PubTabNet dataset will also
+        # recognize text formatting in tables, such as <b>.
+        style_token = [
+            '<strike>', '</strike>', '<sup>', '</sub>', '<b>',
+            '</b>', '<sub>', '</sup>', '<overline>',
+            '</overline>', '<underline>', '</underline>', '<i>',
+            '</i>'
+        ]
+        res = []
+        for box, rec_res in zip(filter_boxes, filter_rec_res):
+            rec_str, rec_conf = rec_res[0], rec_res[1]
+            for token in style_token:
+                if token in rec_str:
+                    rec_str = rec_str.replace(token, '')
+            # if not self.recovery:
+            #     box += [x1, y1]
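+            # NOTE: these boxes come from detection on the full image, so,
+            # unlike the old per-region flow, no offset by the region origin
+            # is needed here.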
+            if self.return_word_box:
+                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist(),
+                    'text_word': word_box_content_list,
+                    'text_word_region': word_box_list
+                })
+            else:
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist()
+                })
+        return res, ocr_time_dict
+
+    def _filter_boxes(self, dt_boxes, bbox):
+        boxes = []
+
+        for idx in range(len(dt_boxes)):
+            box = dt_boxes[idx]
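+            # Assuming the detector returns each box as a 4x2 quadrilateral
+            # ordered clockwise from the top-left point, points 0 and 2 are
+            # opposite corners and give an axis-aligned rect.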
+            rect = box[0][0], box[0][1], box[2][0], box[2][1]
+            if self._has_intersection(bbox, rect):
+                boxes.append(box.tolist())
+
+        return np.array(boxes, np.float32).reshape((len(boxes), 4, 2))
+
+    def _has_intersection(self, rect1, rect2):
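+        # Axis-aligned overlap test: two rects are disjoint iff one lies
+        # entirely to the left/right of, or above/below, the other.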
+        x_min1, y_min1, x_max1, y_max1 = rect1
+        x_min2, y_min2, x_max2, y_max2 = rect2
+        if x_min1 > x_max2 or x_max1 < x_min2:
+            return False
+        if y_min1 > y_max2 or y_max1 < y_min2:
+            return False
+        return True
+

 def save_structure_res(res, save_folder, img_name, img_idx=0):
     excel_save_folder = os.path.join(save_folder, img_name)