
[cherry-pick] Refine table pipes for layout #3649


Merged
31 changes: 30 additions & 1 deletion paddlex/configs/pipelines/PP-StructureV3.yaml
@@ -13,7 +13,9 @@ SubModules:
model_name: PP-DocLayout-L
model_dir: null
threshold:
7: 0.3
0: 0.3 # paragraph_title
7: 0.3 # formula
16: 0.3 # seal
layout_nms: True
layout_unclip_ratio: 1.0
layout_merge_bboxes_mode:
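The layout detector's threshold is now a mapping from layout class id to score threshold (0 = paragraph_title, 7 = formula, 16 = seal). A minimal sketch of how a mapping of this shape filters raw detections; the detection dicts and the default value are illustrative, not PaddleX internals:

from typing import Dict, List

def filter_by_class_threshold(
    dets: List[dict], thresholds: Dict[int, float], default: float = 0.5
) -> List[dict]:
    # Keep a detection only if its score clears the threshold for its class id.
    return [d for d in dets if d["score"] >= thresholds.get(d["class_id"], default)]

dets = [
    {"class_id": 0, "score": 0.35},   # paragraph_title -> kept (0.35 >= 0.3)
    {"class_id": 7, "score": 0.25},   # formula         -> dropped
    {"class_id": 16, "score": 0.90},  # seal            -> kept
]
print(filter_by_class_threshold(dets, {0: 0.3, 7: 0.3, 16: 0.3}))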
@@ -94,6 +96,33 @@ SubPipelines:
module_name: table_cells_detection
model_name: RT-DETR-L_wireless_table_cell_det
model_dir: null
SubPipelines:
GeneralOCR:
pipeline_name: OCR
text_type: general
use_doc_preprocessor: False
use_textline_orientation: True
SubModules:
TextDetection:
module_name: text_detection
model_name: PP-OCRv4_server_det
model_dir: null
limit_side_len: 1200
limit_type: max
thresh: 0.3
box_thresh: 0.4
unclip_ratio: 2.0
TextLineOrientation:
module_name: textline_orientation
model_name: PP-LCNet_x0_25_textline_ori
model_dir: null
batch_size: 1
TextRecognition:
module_name: text_recognition
model_name: PP-OCRv4_server_rec_doc
model_dir: null
batch_size: 6
score_thresh: 0.0

SealRecognition:
pipeline_name: seal_recognition
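With this change, the table recognition branch of PP-StructureV3 carries its own GeneralOCR sub-pipeline (server-grade text detection and recognition plus text-line orientation). A hedged sketch of running the pipeline from a local copy of the YAML above; the paths are placeholders and the exact result fields depend on the installed PaddleX version:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="./PP-StructureV3.yaml")  # config containing the GeneralOCR sub-pipeline
for res in pipeline.predict("./sample_page.png"):
    res.print()                               # inspect layout / table / OCR results
    res.save_to_json(save_path="./output/")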
47 changes: 38 additions & 9 deletions paddlex/inference/pipelines/table_recognition/pipeline.py
@@ -88,6 +88,11 @@ def __init__(
{"pipeline_config_error": "config error for general_ocr_pipeline!"},
)
self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
else:
self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
"GeneralOCR",
None
)

self._crop_by_boxes = CropByBoxes()

@@ -217,6 +222,33 @@ def predict_doc_preprocessor_res(
doc_preprocessor_res = {}
doc_preprocessor_image = image_array
return doc_preprocessor_res, doc_preprocessor_image

def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
"""
Splits OCR bounding boxes by table cells and retrieves text.

Args:
ori_img (ndarray): The original image from which text regions will be extracted.
cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.

Returns:
list: A list containing the recognized texts from each cell.
"""

# Check if cells_bboxes is a list and convert it if not.
if not isinstance(cells_bboxes, list):
cells_bboxes = cells_bboxes.tolist()
texts_list = [] # Initialize a list to store the recognized texts.
# Process each bounding box provided in cells_bboxes.
for i in range(len(cells_bboxes)):
# Extract and round up the coordinates of the bounding box.
x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
# Perform OCR on the defined region of the image and get the recognized text.
rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
# Concatenate the texts and append them to the texts_list.
texts_list.append(''.join(rec_te["rec_texts"]))
# Return the list of recognized texts from each cell.
return texts_list
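
The new helper crops each detected cell out of the original image, runs the general OCR sub-pipeline on the crop, and joins the recognized fragments into one string per cell. A standalone sketch of that loop; run_ocr stands in for self.general_ocr_pipeline, and only the "rec_texts" key used above is assumed:

import math
import numpy as np

def ocr_table_cells(ori_img: np.ndarray, cells_bboxes, run_ocr) -> list:
    texts = []
    for x1, y1, x2, y2 in (map(math.ceil, box) for box in cells_bboxes):
        crop = ori_img[y1:y2, x1:x2, :]   # cell region, HWC layout
        result = next(run_ocr(crop))      # first OCR result for this crop
        texts.append("".join(result["rec_texts"]))
    return texts

fake_ocr = lambda img: iter([{"rec_texts": [f"{img.shape[1]}x{img.shape[0]}"]}])
img = np.zeros((100, 200, 3), dtype=np.uint8)
print(ocr_table_cells(img, [[0.2, 0.7, 50.1, 20.3]], fake_ocr))  # ['50x20']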

def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
"""
@@ -270,15 +302,9 @@ def predict_single_table_recognition_res(
"""
table_structure_pred = next(self.table_structure_model(image_array))
if use_table_cells_ocr_results == True:
table_cells_result = list(
map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
)
table_cells_result = [
[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
]
cells_texts_list = self.split_ocr_bboxes_by_table_cells(
image_array, table_cells_result
)
table_cells_result = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
table_cells_result = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result]
cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
else:
cells_texts_list = []
single_table_recognition_res = get_table_recognition_res(
@@ -381,6 +407,9 @@ def predict(
text_rec_score_thresh=text_rec_score_thresh,
)
)
elif use_table_cells_ocr_results == True:
assert self.general_ocr_config_bak != None
self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)

table_res_list = []
table_region_id = 1
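In the refactored branch above, each row of table_structure_pred["bbox"] holds eight values (four corner points), and indices 0, 1, 4, 5 are kept as the axis-aligned cell box. A sketch of that conversion, assuming the layout [x1, y1, x2, y2, x3, y3, x4, y4] with corners ordered clockwise from the top-left:

import numpy as np

structure_bboxes = np.array([[10.0, 12.0, 90.0, 12.0, 90.0, 40.0, 10.0, 40.0]])
rects = [[q[0], q[1], q[4], q[5]] for q in structure_bboxes.tolist()]
print(rects)  # [[10.0, 12.0, 90.0, 40.0]] -> (x_min, y_min, x_max, y_max)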
45 changes: 17 additions & 28 deletions paddlex/inference/pipelines/table_recognition/pipeline_v2.py
@@ -128,6 +128,11 @@ def __init__(
{"pipeline_config_error": "config error for general_ocr_pipeline!"},
)
self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
else:
self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
"GeneralOCR",
None
)

self._crop_by_boxes = CropByBoxes()

@@ -595,35 +600,23 @@ def predict_single_table_recognition_res(
use_e2e_model = True
else:
table_cells_pred = next(
self.wireless_table_cells_detection_model(
image_array, threshold=0.3
)
self.wireless_table_cells_detection_model(image_array, threshold=0.3)
) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
# If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

if use_e2e_model == False:
table_structure_result = self.extract_results(
table_structure_pred, "table_stru"
)
table_cells_result, table_cells_score = self.extract_results(
table_cells_pred, "det"
)
table_cells_result, table_cells_score = self.cells_det_results_nms(
table_cells_result, table_cells_score
)
ocr_det_boxes = self.get_region_ocr_det_boxes(
overall_ocr_res["rec_boxes"].tolist(), table_box
)
table_structure_result = self.extract_results(table_structure_pred, "table_stru")
table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
table_cells_result = self.cells_det_results_reprocessing(
table_cells_result,
table_cells_score,
ocr_det_boxes,
len(table_structure_pred["bbox"]),
)
if use_table_cells_ocr_results == True:
cells_texts_list = self.split_ocr_bboxes_by_table_cells(
image_array, table_cells_result
)
cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
else:
cells_texts_list = []
single_table_recognition_res = get_table_recognition_res(
@@ -636,16 +629,9 @@
)
else:
if use_table_cells_ocr_results == True:
table_cells_result_e2e = list(
map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
)
table_cells_result_e2e = [
[rect[0], rect[1], rect[4], rect[5]]
for rect in table_cells_result_e2e
]
cells_texts_list = self.split_ocr_bboxes_by_table_cells(
image_array, table_cells_result_e2e
)
table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
else:
cells_texts_list = []
single_table_recognition_res = get_table_recognition_res_e2e(
@@ -749,6 +735,9 @@ def predict(
text_rec_score_thresh=text_rec_score_thresh,
)
)
elif use_table_cells_ocr_results == True:
assert self.general_ocr_config_bak != None
self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)

table_res_list = []
table_region_id = 1
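Both predict() methods now defer building the general OCR sub-pipeline: when it is not created in __init__, its config is stashed in general_ocr_config_bak and the pipeline is built only once use_table_cells_ocr_results is actually requested. A standalone sketch of that lazy-initialization pattern, with illustrative names rather than the real pipeline classes:

class LazyOCRHolder:
    def __init__(self, config: dict, factory):
        self._config_bak = config.get("SubPipelines", {}).get("GeneralOCR")
        self._factory = factory               # stand-in for self.create_pipeline
        self.general_ocr_pipeline = None

    def ensure_ocr(self):
        if self.general_ocr_pipeline is None:
            assert self._config_bak is not None, "GeneralOCR config missing"
            self.general_ocr_pipeline = self._factory(self._config_bak)
        return self.general_ocr_pipeline

holder = LazyOCRHolder(
    {"SubPipelines": {"GeneralOCR": {"pipeline_name": "OCR"}}},
    factory=lambda cfg: f"pipeline<{cfg['pipeline_name']}>",
)
print(holder.ensure_ocr())  # pipeline<OCR>, built on first use only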