diff --git a/paddlex/configs/pipelines/PP-StructureV3.yaml b/paddlex/configs/pipelines/PP-StructureV3.yaml
index 72b16b2bb5..2bb44b6733 100644
--- a/paddlex/configs/pipelines/PP-StructureV3.yaml
+++ b/paddlex/configs/pipelines/PP-StructureV3.yaml
@@ -13,7 +13,9 @@ SubModules:
     model_name: PP-DocLayout-L
     model_dir: null
     threshold:
-      7: 0.3
+      0: 0.3 # paragraph_title
+      7: 0.3 # formula
+      16: 0.3 # seal
     layout_nms: True
     layout_unclip_ratio: 1.0
     layout_merge_bboxes_mode:
@@ -94,6 +96,33 @@ SubPipelines:
         module_name: table_cells_detection
         model_name: RT-DETR-L_wireless_table_cell_det
         model_dir: null
+    SubPipelines:
+      GeneralOCR:
+        pipeline_name: OCR
+        text_type: general
+        use_doc_preprocessor: False
+        use_textline_orientation: True
+        SubModules:
+          TextDetection:
+            module_name: text_detection
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            limit_side_len: 1200
+            limit_type: max
+            thresh: 0.3
+            box_thresh: 0.4
+            unclip_ratio: 2.0
+          TextLineOrientation:
+            module_name: textline_orientation
+            model_name: PP-LCNet_x0_25_textline_ori
+            model_dir: null
+            batch_size: 1
+          TextRecognition:
+            module_name: text_recognition
+            model_name: PP-OCRv4_server_rec_doc
+            model_dir: null
+            batch_size: 6
+            score_thresh: 0.0

   SealRecognition:
     pipeline_name: seal_recognition
diff --git a/paddlex/inference/pipelines/table_recognition/pipeline.py b/paddlex/inference/pipelines/table_recognition/pipeline.py
index 6244e77de9..3eb32a74e5 100644
--- a/paddlex/inference/pipelines/table_recognition/pipeline.py
+++ b/paddlex/inference/pipelines/table_recognition/pipeline.py
@@ -88,6 +88,11 @@ def __init__(
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )

         self._crop_by_boxes = CropByBoxes()

@@ -217,6 +222,33 @@ def predict_doc_preprocessor_res(
             doc_preprocessor_res = {}
             doc_preprocessor_image = image_array
         return doc_preprocessor_res, doc_preprocessor_image
+
+    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
+        """
+        Splits OCR bounding boxes by table cells and retrieves text.
+
+        Args:
+            ori_img (ndarray): The original image from which text regions will be extracted.
+            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
+
+        Returns:
+            list: A list containing the recognized texts from each cell.
+        """
+
+        # Check if cells_bboxes is a list and convert it if not.
+        if not isinstance(cells_bboxes, list):
+            cells_bboxes = cells_bboxes.tolist()
+        texts_list = []  # Initialize a list to store the recognized texts.
+        # Process each bounding box provided in cells_bboxes.
+        for i in range(len(cells_bboxes)):
+            # Extract and round up the coordinates of the bounding box.
+            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
+            # Perform OCR on the defined region of the image and get the recognized text.
+            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+            # Concatenate the texts and append them to the texts_list.
+            texts_list.append(''.join(rec_te["rec_texts"]))
+        # Return the list of recognized texts from each cell.
+        return texts_list

     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
@@ -270,15 +302,9 @@ def predict_single_table_recognition_res(
         """
         table_structure_pred = next(self.table_structure_model(image_array))
         if use_table_cells_ocr_results == True:
-            table_cells_result = list(
-                map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-            )
-            table_cells_result = [
-                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
-            ]
-            cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                image_array, table_cells_result
-            )
+            table_cells_result = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+            table_cells_result = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result]
+            cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
         else:
             cells_texts_list = []
         single_table_recognition_res = get_table_recognition_res(
@@ -381,6 +407,9 @@ def predict(
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
             )
+        elif use_table_cells_ocr_results == True:
+            assert self.general_ocr_config_bak != None
+            self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)

         table_res_list = []
         table_region_id = 1
diff --git a/paddlex/inference/pipelines/table_recognition/pipeline_v2.py b/paddlex/inference/pipelines/table_recognition/pipeline_v2.py
index 885db340a6..1990676645 100644
--- a/paddlex/inference/pipelines/table_recognition/pipeline_v2.py
+++ b/paddlex/inference/pipelines/table_recognition/pipeline_v2.py
@@ -128,6 +128,11 @@ def __init__(
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )

         self._crop_by_boxes = CropByBoxes()

@@ -595,25 +600,15 @@ def predict_single_table_recognition_res(
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(
-                        image_array, threshold=0.3
-                    )
+                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
                 )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

         if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
-            table_cells_result, table_cells_score = self.extract_results(
-                table_cells_pred, "det"
-            )
-            table_cells_result, table_cells_score = self.cells_det_results_nms(
-                table_cells_result, table_cells_score
-            )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
+            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
+            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
             table_cells_result = self.cells_det_results_reprocessing(
                 table_cells_result,
                 table_cells_score,
@@ -621,9 +616,7 @@
                 len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
-                )
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
@@ -636,16 +629,9 @@
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-                )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
-                )
+                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
@@ -749,6 +735,9 @@ def predict(
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
             )
+        elif use_table_cells_ocr_results == True:
+            assert self.general_ocr_config_bak != None
+            self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)

         table_res_list = []
         table_region_id = 1
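
For context when reviewing: the `GeneralOCR` sub-config added to the YAML is what `general_ocr_config_bak` picks up, so cell-level OCR now works even when the table recognition pipeline is built with `use_ocr_model: False`. Below is a minimal usage sketch of that path. It is not part of the diff and rests on a few assumptions: the PaddleX `create_pipeline` entry point, the public `predict()` wrapper forwarding `use_table_cells_ocr_results` to the code changed above, the usual result-saving helpers, and a placeholder image `table.jpg`.

```python
from paddlex import create_pipeline

# Standalone table recognition pipeline; with use_ocr_model=False the GeneralOCR
# sub-pipeline is only instantiated lazily inside predict() from the backed-up
# config (general_ocr_config_bak) introduced in this diff.
pipeline = create_pipeline(pipeline="table_recognition")

# use_table_cells_ocr_results=True crops every predicted cell box and runs OCR on
# it (split_ocr_bboxes_by_table_cells) instead of reusing the overall OCR boxes.
# "table.jpg" is a hypothetical input path.
output = pipeline.predict("table.jpg", use_table_cells_ocr_results=True)

for res in output:
    res.print()                    # recognized structure and per-cell texts
    res.save_to_xlsx("./output/")  # assumed available on table recognition results
    res.save_to_img("./output/")
```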