Skip to content

Commit d521203

Browse files
authored
support inline formula embedding & update reference format (#2907)
1 parent f06578e commit d521203

File tree

3 files changed

+50
-22
lines changed

3 files changed

+50
-22
lines changed

paddlex/inference/pipelines_new/layout_parsing/pipeline_v2.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,24 @@ def predict(
336336
self.layout_det_model(doc_preprocessor_image),
337337
)
338338

339+
if model_settings["use_formula_recognition"]:
340+
formula_res_all = next(
341+
self.formula_recognition_pipeline(
342+
doc_preprocessor_image,
343+
use_layout_detection=False,
344+
use_doc_orientation_classify=False,
345+
use_doc_unwarping=False,
346+
layout_det_res=layout_det_res,
347+
),
348+
)
349+
formula_res_list = formula_res_all["formula_res_list"]
350+
else:
351+
formula_res_list = []
352+
353+
for formula_res in formula_res_list:
354+
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
355+
doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = 255.0
356+
339357
if (
340358
model_settings["use_general_ocr"]
341359
or model_settings["use_table_recognition"]
@@ -351,6 +369,24 @@ def predict(
351369
text_rec_score_thresh=text_rec_score_thresh,
352370
),
353371
)
372+
373+
for formula_res in formula_res_list:
374+
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
375+
poly_points = [
376+
(x_min, y_min),
377+
(x_max, y_min),
378+
(x_max, y_max),
379+
(x_min, y_max),
380+
]
381+
overall_ocr_res["dt_polys"].append(poly_points)
382+
overall_ocr_res["rec_texts"].append(
383+
f"${formula_res['rec_formula']}$"
384+
)
385+
overall_ocr_res["rec_boxes"] = np.vstack(
386+
(overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
387+
)
388+
overall_ocr_res["rec_polys"].append(poly_points)
389+
overall_ocr_res["rec_scores"].append(1)
354390
else:
355391
overall_ocr_res = {}
356392

@@ -398,22 +434,11 @@ def predict(
398434
else:
399435
seal_res_list = []
400436

401-
if model_settings["use_formula_recognition"]:
402-
formula_res_all = next(
403-
self.formula_recognition_pipeline(
404-
doc_preprocessor_image,
405-
use_layout_detection=False,
406-
use_doc_orientation_classify=False,
407-
use_doc_unwarping=False,
408-
layout_det_res=layout_det_res,
409-
),
410-
)
411-
formula_res_list = formula_res_all["formula_res_list"]
412-
else:
413-
formula_res_list = []
414-
415-
for table_res in table_res_list:
416-
table_res["layout_bbox"] = table_res["cell_box_list"][0]
437+
for formula_res in formula_res_list:
438+
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
439+
doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = formula_res[
440+
"input_img"
441+
]
417442

418443
structure_res = get_structure_res(
419444
overall_ocr_res,

paddlex/inference/pipelines_new/layout_parsing/result_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,11 @@ def format_chart():
395395
return "\n".join(img_tags)
396396

397397
def format_reference():
398-
pattern = r"\[\d+\]"
398+
pattern = r"\s*\[\s*\d+\s*\]\s*"
399399
res = re.sub(
400400
pattern,
401401
lambda match: "\n" + match.group(),
402-
sub_block["reference"],
402+
sub_block["reference"].replace("\n", ""),
403403
)
404404
return "\n" + res
405405

paddlex/inference/pipelines_new/layout_parsing/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
191191
first_span = line[0]
192192
end_span = line[-1]
193193
if first_span[0][0] - x_min > 20:
194-
first_span[1] = "\n " + first_span[1]
194+
first_span[1] = "\n" + first_span[1]
195195
if x_max - end_span[0][2] > 20:
196196
end_span[1] = end_span[1] + "\n"
197197

@@ -235,13 +235,12 @@ def get_structure_res(
235235
layout_bbox = box_info["coordinate"]
236236
label = box_info["label"]
237237
rec_res = {"boxes": [], "rec_texts": [], "flag": False}
238-
drop_index = []
239238
seg_start_flag = True
240239
seg_end_flag = True
241240

242241
if label == "table":
243242
for i, table_res in enumerate(table_res_list):
244-
if calculate_iou(layout_bbox, table_res["layout_bbox"]) > 0.5:
243+
if calculate_iou(layout_bbox, table_res["cell_box_list"][0]) > 0.5:
245244
structure_boxes.append(
246245
{
247246
"label": label,
@@ -262,7 +261,6 @@ def get_structure_res(
262261
overall_ocr_res["rec_texts"][box_no],
263262
)
264263
rec_res["flag"] = True
265-
drop_index.append(box_no)
266264

267265
if rec_res["flag"]:
268266
rec_res = _sort_box_by_y_projection(layout_bbox, rec_res, 0.7)
@@ -272,6 +270,11 @@ def get_structure_res(
272270
seg_start_flag = False
273271
if layout_bbox[2] - rec_res_end_bbox[2] < 20:
274272
seg_end_flag = False
273+
if label == "formula":
274+
rec_res["rec_texts"] = [
275+
rec_res_text.replace("$", "")
276+
for rec_res_text in rec_res["rec_texts"]
277+
]
275278

276279
if label in ["chart", "image"]:
277280
structure_boxes.append(

0 commit comments

Comments
 (0)