Commit a96ba4f

Optimize StructureSystem for better OCR accuracy
1 parent e7c0fb9 · commit a96ba4f

File tree

2 files changed: +97, -135 lines


paddleocr.py (+5 -93)

@@ -634,10 +634,10 @@ def __init__(self, **kwargs):
         super().__init__(params)
         self.page_num = params.page_num
 
-    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_color=(255, 255, 255), dt_boxes=None):
+    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_color=(255, 255, 255)):
         """
         OCR with PaddleOCR
-
+
         args:
             img: img for OCR, support ndarray, img_path and list or ndarray
             det: use text detection or not. If False, only rec will be exec. Default is True
@@ -646,7 +646,6 @@ def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_col
             bin: binarize image to black and white. Default is False.
             inv: invert image colors. Default is False.
             alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
-            dt_boxes: user-specified bounding boxes for OCR. If None, the boxes will be detected automatically.
         """
         assert isinstance(img, (np.ndarray, list, str, bytes))
         if isinstance(img, list) and det == True:
@@ -679,7 +678,7 @@ def preprocess_image(_image):
         ocr_res = []
         for idx, img in enumerate(imgs):
             img = preprocess_image(img)
-            dt_boxes, rec_res, _ = self.__call__(img, cls, dt_boxes=dt_boxes)
+            dt_boxes, rec_res, _ = self.__call__(img, cls)
             if not dt_boxes and not rec_res:
                 ocr_res.append(None)
                 continue
@@ -720,27 +719,6 @@ class PPStructure(StructureSystem):
     def __init__(self, **kwargs):
         params = parse_args(mMain=False)
         params.__dict__.update(**kwargs)
-
-        # As reported in issues such as #10270 and #11665, the current
-        # implementation has problems with the precision of OCR recognition.
-        #
-        # To address this issue, here we implement a patch fix by employing a
-        # combination of PaddleOCR (TextSystem) and StructureSystem.
-        self._args = params
-
-        if self._args.ocr:
-            # If OCR is enabled, we first initialize the structure engine without
-            # enabling OCR, and then initialize a standalone OCR engine.
-            kwargs.pop('ocr', None)
-            self._init_structure(ocr=False, **kwargs)
-            self._ocr_engine = PaddleOCR(**kwargs)
-        else:
-            # Init the structure engine with the raw parameters.
-            self._init_structure(**kwargs)
-
-    def _init_structure(self, **kwargs):
-        params = parse_args(mMain=False)
-        params.__dict__.update(**kwargs)
         assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
             SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
         params.use_gpu = check_gpu(params.use_gpu)
@@ -796,78 +774,12 @@ def _init_structure(self, **kwargs):
         logger.debug(params)
         super().__init__(params)
 
-    def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
-        if not self._args.ocr:
-            return self._predict_structure(img, return_ocr_result_in_table, img_idx, self._args.alphacolor)
-
-        # We first detect all text regions by using the OCR engine.
-        dt_boxes, elapse = self._ocr_engine.text_detector(img)
-
-        # Then do layout analysis by using the structure engine.
-        result = self._predict_structure(img, return_ocr_result_in_table, img_idx, self._args.alphacolor)
-        for r in result:
-            # Ignore tables since they are parsed separately by the internal table model.
-            if r['type'] == 'table':
-                continue
-
-            # Keep only the regions that intersect with the current bbox.
-            r_dt_boxes = self._filter_boxes(dt_boxes, r['bbox'])
-
-            # Perform OCR recognition on texts within these regions.
-            ocr_result = self._ocr_engine.ocr(img,
-                                              det=self._args.det,
-                                              rec=self._args.rec,
-                                              cls=self._args.use_angle_cls,
-                                              bin=self._args.binarize,
-                                              inv=self._args.invert,
-                                              alpha_color=self._args.alphacolor,
-                                              dt_boxes=r_dt_boxes)
-            if ocr_result:
-                ocr_r = ocr_result[0]
-                if ocr_r:  # Sometimes ocr_r might be None.
-                    r['res'] = [
-                        dict(
-                            text_region=x[0],
-                            text=x[1][0],
-                            confidence=x[1][1],
-                        )
-                        for x in ocr_r
-                    ]
-
-        # Sort the text boxes in order from top to bottom and from left to right.
-        from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes
-        h, w, _ = img.shape
-        sorted_result = sorted_layout_boxes(result, w)
-
-        return sorted_result
-
-    def _predict_structure(self, img, return_ocr_result_in_table=False, img_idx=0, alpha_color=(255, 255, 255)):
+    def __call__(self, img, return_ocr_result_in_table=False, img_idx=0, alpha_color=(255, 255, 255)):
         img = check_img(img, alpha_color)
         res, _ = super().__call__(
             img, return_ocr_result_in_table, img_idx=img_idx)
         return res
 
-    def _filter_boxes(self, dt_boxes, bbox):
-        # TODO(RussellLuo): Performance needs improvement?
-        boxes = []
-
-        for idx in range(len(dt_boxes)):
-            box = dt_boxes[idx]
-            rect = box[0][0], box[0][1], box[2][0], box[2][1]
-            if self._has_intersection(bbox, rect):
-                boxes.append(box.tolist())
-
-        return np.array(boxes, np.float32).reshape((len(boxes), 4, 2))
-
-    def _has_intersection(self, rect1, rect2):
-        x_min1, y_min1, x_max1, y_max1 = rect1
-        x_min2, y_min2, x_max2, y_max2 = rect2
-        if x_min1 > x_max2 or x_max1 < x_min2:
-            return False
-        if y_min1 > y_max2 or y_max1 < y_min2:
-            return False
-        return True
-
 
 def main():
     # for cmd
@@ -920,7 +832,7 @@ def main():
            outfile = args.output + '/' + img_name + '.txt'
            with open(outfile,'w',encoding='utf-8') as f:
                f.writelines(lines)
-
+
        elif args.type == 'structure':
            img, flag_gif, flag_pdf = check_and_read(img_path)
            if not flag_gif and not flag_pdf:

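Net effect on the public API: dt_boxes is gone from ocr(), and PPStructure.__call__ absorbs the alpha_color parameter from the removed _predict_structure helper. A minimal usage sketch of the API after this commit; the file name, constructor options, and prints are illustrative assumptions, not part of the diff:

# Minimal usage sketch (illustrative inputs, not from the diff).
from paddleocr import PaddleOCR, PPStructure

ocr = PaddleOCR(use_angle_cls=True, lang='en')
# ocr() no longer accepts dt_boxes; detection always runs internally.
result = ocr.ocr('page.png', det=True, rec=True, cls=True)
if result and result[0]:
    for box, (text, confidence) in result[0]:
        print(text, confidence)

# PPStructure.__call__ now takes alpha_color directly,
# replacing the removed _predict_structure helper.
structure_engine = PPStructure(ocr=True)
for region in structure_engine('page.png', alpha_color=(255, 255, 255)):
    print(region['type'], region['bbox'])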
ppstructure/predict_system.py (+92 -42)

@@ -58,7 +58,6 @@ def __init__(self, args):
                logger.warning(
                    "When args.layout is false, args.ocr is automatically set to false"
                )
-            args.drop_score = 0
            # init model
            self.layout_predictor = None
            self.text_system = None
@@ -93,6 +92,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
            'all': 0
        }
        start = time.time()
+
        if self.image_orientation_predictor is not None:
            tic = time.time()
            cls_result = self.image_orientation_predictor.predict(
@@ -108,6 +108,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                img = cv2.rotate(img, cv_rotate_code[angle])
            toc = time.time()
            time_dict['image_orientation'] = toc - tic
+
        if self.mode == 'structure':
            ori_im = img.copy()
            if self.layout_predictor is not None:
@@ -116,6 +117,20 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
            else:
                h, w = ori_im.shape[:2]
                layout_res = [dict(bbox=None, label='table')]
+
+            # As reported in issues such as #10270 and #11665, the old
+            # implementation, which recognizes texts from the layout regions,
+            # has problems with OCR recognition accuracy.
+            #
+            # To enhance the OCR recognition accuracy, we implement a patch fix
+            # that first detects all text regions by using the text_detector
+            # and then recognizes the texts from the text regions (intersecting
+            # with the layout regions) by using the text_recognizer.
+            dt_boxes = []
+            if self.text_system is not None:
+                dt_boxes, elapse = self.text_system.text_detector(img)
+                time_dict['det'] = elapse
+
            res_list = []
            for region in layout_res:
                res = ''
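The hunk above is the core of the change: text detection now runs once over the whole page before the region loop, and each non-table region later keeps only the detected boxes that intersect it. A self-contained sketch of that control flow; run_layout and run_text_detector are stand-in stubs rather than PaddleOCR APIs, and only the detect-once-then-filter orchestration mirrors the commit:

import numpy as np

# Stubs standing in for the real layout predictor and text detector.
def run_layout(img):
    return [{'label': 'text', 'bbox': [0, 0, 100, 40]},
            {'label': 'table', 'bbox': [0, 50, 100, 90]}]

def run_text_detector(img):
    # Quadrilaterals in the (N, 4, 2) layout the text detector emits.
    return np.array([[[5, 5], [60, 5], [60, 20], [5, 20]],
                     [[5, 60], [60, 60], [60, 75], [5, 75]]], np.float32)

def has_intersection(bbox, rect):
    x_min1, y_min1, x_max1, y_max1 = bbox
    x_min2, y_min2, x_max2, y_max2 = rect
    return not (x_min1 > x_max2 or x_max1 < x_min2 or
                y_min1 > y_max2 or y_max1 < y_min2)

img = np.zeros((100, 100, 3), np.uint8)
dt_boxes = run_text_detector(img)   # detection runs once, on the full page
for region in run_layout(img):
    if region['label'] == 'table':
        continue                    # tables keep their own table pipeline
    hits = [b for b in dt_boxes
            if has_intersection(region['bbox'],
                                (b[0][0], b[0][1], b[2][0], b[2][1]))]
    print(region['label'], len(hits))   # -> text 1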
@@ -126,6 +141,8 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0)
                else:
                    x1, y1, x2, y2 = 0, 0, w, h
                    roi_img = ori_im
+                bbox = [x1, y1, x2, y2]
+
                if region['label'] == 'table':
                    if self.table_system is not None:
                        res, table_time_dict = self.table_system(
@@ -136,66 +153,99 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                        time_dict['rec'] += table_time_dict['rec']
                else:
                    if self.text_system is not None:
-                        if self.recovery:
-                            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
-                            wht_im[y1:y2, x1:x2, :] = roi_img
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                wht_im)
-                        else:
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                roi_img)
+                        res, ocr_time_dict = self._predict_text(ori_im, roi_img, bbox, dt_boxes)
                        time_dict['det'] += ocr_time_dict['det']
                        time_dict['rec'] += ocr_time_dict['rec']
 
-                        # remove style char,
-                        # when using the recognition model trained on the PubtabNet dataset,
-                        # it will recognize the text format in the table, such as <b>
-                        style_token = [
-                            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
-                            '</b>', '<sub>', '</sup>', '<overline>',
-                            '</overline>', '<underline>', '</underline>', '<i>',
-                            '</i>'
-                        ]
-                        res = []
-                        for box, rec_res in zip(filter_boxes, filter_rec_res):
-                            rec_str, rec_conf = rec_res[0], rec_res[1]
-                            for token in style_token:
-                                if token in rec_str:
-                                    rec_str = rec_str.replace(token, '')
-                            if not self.recovery:
-                                box += [x1, y1]
-                            if self.return_word_box:
-                                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist(),
-                                    'text_word': word_box_content_list,
-                                    'text_word_region': word_box_list
-                                })
-                            else:
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist()
-                                })
                res_list.append({
                    'type': region['label'].lower(),
-                    'bbox': [x1, y1, x2, y2],
+                    'bbox': bbox,
                    'img': roi_img,
                    'res': res,
                    'img_idx': img_idx
                })
+
            end = time.time()
            time_dict['all'] = end - start
            return res_list, time_dict
+
        elif self.mode == 'kie':
            re_res, elapse = self.kie_predictor(img)
            time_dict['kie'] = elapse
            time_dict['all'] = elapse
            return re_res[0], time_dict
+
        return None, None
 
+    def _predict_text(self, ori_im, roi_img, bbox, dt_boxes):
+        x1, y1, x2, y2 = bbox
+
+        if self.recovery:
+            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
+            wht_im[y1:y2, x1:x2, :] = roi_img
+            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+                wht_im)
+        else:
+            # Filter the text regions that intersect with the current bbox.
+            intersecting_dt_boxes = self._filter_boxes(dt_boxes, bbox)
+            # Recognize texts from these intersecting text regions.
+            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+                ori_im, dt_boxes=intersecting_dt_boxes)
+
+        # remove style char,
+        # when using the recognition model trained on the PubtabNet dataset,
+        # it will recognize the text format in the table, such as <b>
+        style_token = [
+            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
+            '</b>', '<sub>', '</sup>', '<overline>',
+            '</overline>', '<underline>', '</underline>', '<i>',
+            '</i>'
+        ]
+        res = []
+        for box, rec_res in zip(filter_boxes, filter_rec_res):
+            rec_str, rec_conf = rec_res[0], rec_res[1]
+            for token in style_token:
+                if token in rec_str:
+                    rec_str = rec_str.replace(token, '')
+            # if not self.recovery:
+            #     box += [x1, y1]
+            if self.return_word_box:
+                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist(),
+                    'text_word': word_box_content_list,
+                    'text_word_region': word_box_list
+                })
+            else:
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist()
+                })
+        return res, ocr_time_dict
+
+    def _filter_boxes(self, dt_boxes, bbox):
+        boxes = []
+
+        for idx in range(len(dt_boxes)):
+            box = dt_boxes[idx]
+            rect = box[0][0], box[0][1], box[2][0], box[2][1]
+            if self._has_intersection(bbox, rect):
+                boxes.append(box.tolist())
+
+        return np.array(boxes, np.float32).reshape((len(boxes), 4, 2))
+
+    def _has_intersection(self, rect1, rect2):
+        x_min1, y_min1, x_max1, y_max1 = rect1
+        x_min2, y_min2, x_max2, y_max2 = rect2
+        if x_min1 > x_max2 or x_max1 < x_min2:
+            return False
+        if y_min1 > y_max2 or y_max1 < y_min2:
+            return False
+        return True
+
 
 def save_structure_res(res, save_folder, img_name, img_idx=0):
    excel_save_folder = os.path.join(save_folder, img_name)

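One detail worth noting in the new _filter_boxes: the final reshape((len(boxes), 4, 2)) keeps the result well-formed even when no detected box intersects the region, so callers can treat the empty case uniformly. A quick standalone check with plain numpy (nothing PaddleOCR-specific):

import numpy as np

# With no intersecting boxes, the reshape still yields a (0, 4, 2) array,
# so callers can iterate or index it without special-casing emptiness.
boxes = []
empty = np.array(boxes, np.float32).reshape((len(boxes), 4, 2))
print(empty.shape)   # (0, 4, 2)
for box in empty:    # loop body simply never runs
    pass

one = [[[0, 0], [10, 0], [10, 5], [0, 5]]]
print(np.array(one, np.float32).reshape((len(one), 4, 2)).shape)   # (1, 4, 2)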