table rec code #11999

Merged
merged 13 commits into from
May 16, 2024
4 changes: 2 additions & 2 deletions configs/table/SLANet.yml
@@ -101,7 +101,7 @@ Train:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: True
     batch_size_per_card: 48
@@ -137,7 +137,7 @@ Eval:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: False
     drop_last: False
4 changes: 2 additions & 2 deletions configs/table/SLANet_ch.yml
@@ -97,7 +97,7 @@ Train:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: True
     batch_size_per_card: 48
@@ -133,7 +133,7 @@ Eval:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: False
     drop_last: False
139 changes: 139 additions & 0 deletions configs/table/SLANet_lcnetv2.yml
@@ -0,0 +1,139 @@
Global:
  use_gpu: true
  epoch_num: 50
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/SLANet_lcnetv2
  save_epoch_step: 20
  # evaluation is run every 1000 iterations after the 0th iteration
  eval_batch_step: [0, 1000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir: ./SLANet_lcnetv2_infer
  use_visualdl: False
  infer_img: ppstructure/docs/table/table.jpg
  # for data or label process
  character_dict_path: ppocr/utils/dict/table_structure_dict.txt
  character_type: en
  max_text_length: &max_text_length 500
  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
  infer_mode: False
  use_sync_bn: True
  save_res_path: 'output/infer'
  d2s_train_image_shape: [3, -1, -1]
  amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  clip_norm: 5.0
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0.00000

Architecture:
  model_type: table
  algorithm: SLANet
  Backbone:
    name: PPLCNetV2_base
  Neck:
    name: CSPPAN
    out_channels: 96
  Head:
    name: SLAHead
    hidden_size: 256
    max_text_length: *max_text_length
    loc_reg_num: &loc_reg_num 4

Loss:
  name: SLALoss
  structure_weight: 1.0
  loc_weight: 2.0
  loc_loss: smooth_l1

PostProcess:
  name: TableLabelDecode
  merge_no_span_structure: &merge_no_span_structure True

Metric:
  name: TableMetric
  main_indicator: acc
  compute_bbox_metric: False
  loc_reg_num: *loc_reg_num
  box_format: *box_format

Train:
  dataset:
    name: PubTabDataSet
    data_dir: ../table_data/pubtabnet/train/
    label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_train.jsonl]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - TableLabelEncode:
          learn_empty_box: False
          merge_no_span_structure: *merge_no_span_structure
          replace_empty_cell_token: False
          loc_reg_num: *loc_reg_num
          max_text_length: *max_text_length
      - TableBoxEncode:
          in_box_format: *box_format
          out_box_format: *box_format
      - ResizeTableImage:
          max_len: 488
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - PaddingTableImage:
          size: [488, 488]
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape' ]
  loader:
    shuffle: True
    batch_size_per_card: 24
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: PubTabDataSet
    data_dir: ../table_data/pubtabnet/val/
    label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_val.jsonl]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - TableLabelEncode:
          learn_empty_box: False
          merge_no_span_structure: *merge_no_span_structure
          replace_empty_cell_token: False
          loc_reg_num: *loc_reg_num
          max_text_length: *max_text_length
      - TableBoxEncode:
          in_box_format: *box_format
          out_box_format: *box_format
      - ResizeTableImage:
          max_len: 488
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - PaddingTableImage:
          size: [488, 488]
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 48
    num_workers: 4
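
The `ResizeTableImage` (`max_len: 488`) and `PaddingTableImage` (`size: [488, 488]`) entries in this config suggest a resize-then-pad pipeline. The sketch below only illustrates the shape arithmetic. It assumes the longest side is scaled to `max_len` with the aspect ratio preserved and that padding is added on the bottom and right; the function names are hypothetical and this is not the PaddleOCR implementation.

```python
from typing import Tuple


def resized_shape(height: int, width: int, max_len: int) -> Tuple[int, int]:
    """Scale (height, width) so the longest side equals max_len (assumed behavior)."""
    longest = max(height, width)
    # Integer arithmetic avoids floating-point rounding surprises.
    return height * max_len // longest, width * max_len // longest


def padding_needed(height: int, width: int, size: Tuple[int, int]) -> Tuple[int, int]:
    """Rows and columns of padding needed to reach the target size (assumed bottom/right)."""
    target_h, target_w = size
    return target_h - height, target_w - width


h, w = resized_shape(1000, 500, max_len=488)
pad_h, pad_w = padding_needed(h, w, size=(488, 488))
print((h, w), (pad_h, pad_w))
```

Under these assumptions every image reaches the network as a 488 x 488 input, which is why the `max_len` and `size` values have to change together when the resolution is raised.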
111 changes: 111 additions & 0 deletions doc/doc_ch/algorithm_table_slanet.md
@@ -0,0 +1,111 @@
# Table Recognition Algorithm: SLANet-LCNetV2

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training, Evaluation, and Prediction](#3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving Deployment](#4-3)
    - [4.4 More Deployment Options](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
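## 1. Introduction

This algorithm ranked first on the leaderboard of the PaddleOCR Algorithm Model Challenge, Task 2: General Table Recognition. Its core ideas are:

- 1. Improved inference: decoding stops at EOS, for a 3x speedup
- 2. Backbone upgraded to LCNetV2 (SSLD version)
- 3. A row and column feature enhancement module
- 4. Input resolution increased from 488 to 512
- 5. A three-stage training strategy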
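
The first idea, stopping at EOS, can be illustrated with a small greedy-decoding loop. This is a minimal sketch, not the SLANet head: `step_fn`, `greedy_decode`, and the token ids are hypothetical, and the only assumption is that decoding can stop as soon as every sequence in the batch has emitted EOS rather than always running for the maximum number of steps.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step_fn: Callable[[List[List[int]]], Sequence[int]],
    batch_size: int,
    sos_idx: int,
    eos_idx: int,
    max_steps: int,
) -> List[List[int]]:
    """Greedy decoding that exits once every sequence has produced EOS."""
    sequences = [[sos_idx] for _ in range(batch_size)]
    finished = [False] * batch_size
    for _ in range(max_steps):
        next_tokens = step_fn(sequences)  # one predicted token id per sequence
        for i, token in enumerate(next_tokens):
            if not finished[i]:
                sequences[i].append(token)
                finished[i] = token == eos_idx
        if all(finished):
            break  # early exit instead of running all max_steps
    return [seq[1:] for seq in sequences]  # drop the leading SOS
```

With `max_steps` set to 500 (the `max_text_length` in the config) and tables that need far fewer tokens, most decoder steps are skipped, which is the kind of saving the 3x figure refers to.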

On the public PubTabNet table recognition dataset, the reproduced results are as follows:

|Model|Backbone|Config|acc|
| --- | --- | --- | --- |
|SLANet|LCNetV2|[configs/table/SLANet_lcnetv2.yml](../../configs/table/SLANet_lcnetv2.yml)|76.67%|


<a name="2"></a>
## 2. Environment
Please refer to [Environment Preparation](./environment.md) to configure the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.


<a name="3"></a>
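## 3. Model Training, Evaluation, and Prediction

The SLANet_LCNetv2 model above was trained on the public PubTabNet table recognition dataset. For dataset download, see [table_datasets](./dataset/table_datasets.md).

### Start training

After downloading the data, refer to the [text recognition tutorial](./recognition.md) for training. PaddleOCR's code is modular, so training a different model only requires **changing the configuration file**.

The training commands are as follows: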
```shell
# stage1
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/table/SLANet_lcnetv2.yml
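# stage2: load the stage1 best model as the pretrained model and set the learning rate to 0.0001.
# stage3: load the stage2 best model as the pretrained model, keep the learning rate unchanged, and change every 488 in the configuration file to 512.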
```
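
The stage 3 instruction (change every 488 in the configuration file to 512) can be sketched as a recursive rewrite of the loaded config. The helper below is a hypothetical illustration on a hand-written fragment of the config, not a PaddleOCR utility; editing the YAML file by hand gives the same result.

```python
from typing import Any


def replace_value(node: Any, old: int, new: int) -> Any:
    """Return a copy of a nested config with every int equal to `old` replaced by `new`."""
    if isinstance(node, dict):
        return {key: replace_value(value, old, new) for key, value in node.items()}
    if isinstance(node, list):
        return [replace_value(value, old, new) for value in node]
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(node, int) and not isinstance(node, bool) and node == old:
        return new
    return node


# Fragment mirroring the two places where 488 appears in each transform list.
stage1_fragment = {
    "transforms": [
        {"ResizeTableImage": {"max_len": 488}},
        {"PaddingTableImage": {"size": [488, 488]}},
    ]
}
stage3_fragment = replace_value(stage1_fragment, old=488, new=512)
print(stage3_fragment)
```

The rewrite has to cover both the `Train` and `Eval` transform lists so that training and evaluation use the same input resolution.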

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference
Convert the best model obtained from training into an inference model with the following command:

```shell
# Note: set pretrained_model to your local path.
python3 tools/export_model.py -c configs/table/SLANet_lcnetv2.yml -o Global.pretrained_model=path/best_accuracy Global.save_inference_dir=./inference/slanet_lcnetv2_infer
```

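**Note:**
- If you trained the model on your own dataset and changed the dictionary file, make sure `character_dict_path` in the configuration file points to the correct dictionary file.

After a successful conversion, the directory contains three files: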
```
./inference/slanet_lcnetv2_infer/
├── inference.pdiparams         # parameter file of the inference model
├── inference.pdiparams.info    # parameter info of the inference model, can be ignored
└── inference.pdmodel           # program file of the inference model
```


Run the following command to perform inference:

```shell
cd ppstructure/
python3.7 table/predict_structure.py --table_model_dir=../inference/slanet_lcnetv2_infer/ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt --image_dir=docs/table/table.jpg --output=../output/table_slanet_lcnetv2 --use_gpu=False --benchmark=True --enable_mkldnn=True
# To predict all images in a folder, set image_dir to the folder, e.g. --image_dir='docs/table'.
```

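After the command runs, the prediction for the image above (the structure information and the coordinates of each table cell) is printed to the screen, and a visualization of the cell coordinates is saved. Example output: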
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
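
The `result` in the log is a list of structure tokens followed by one box per cell. As a minimal sketch of how such a token list might be consumed, the hypothetical helpers below join the tokens into an HTML string and count rows and cells. They assume tokens shaped like the ones in the log and are not part of `predict_structure.py`.

```python
from typing import Dict, List


def structure_to_html(tokens: List[str]) -> str:
    """Join predicted structure tokens into a single HTML string."""
    return "".join(tokens)


def summarize_structure(tokens: List[str]) -> Dict[str, int]:
    """Count table rows and cells in a structure token list."""
    rows = sum(1 for token in tokens if token == "<tr>")
    # Cells may appear as '<td></td>' or as an opening '<td' before span attributes.
    cells = sum(1 for token in tokens if token.startswith("<td"))
    return {"rows": rows, "cells": cells}


tokens = ["<html>", "<body>", "<table>", "<tr>", "<td></td>", "<td></td>",
          "</tr>", "</table>", "</body>", "</html>"]
print(structure_to_html(tokens))
print(summarize_structure(tokens))
```

Comparing the cell count with the number of predicted boxes is a quick sanity check on a prediction.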

<a name="4-2"></a>
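### 4.2 C++ Inference

Not supported yet: the C++ pre-processing and post-processing do not support SLANet.

<a name="4-3"></a>
### 4.3 Serving Deployment

Not supported yet.

<a name="4-4"></a>
### 4.4 More Deployment Options

Not supported yet.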

<a name="5"></a>
## 5. FAQ
2 changes: 1 addition & 1 deletion ppocr/data/imaug/label_ops.py
@@ -708,7 +708,7 @@ def __call__(self, data):
         structure = self.encode(new_structure)
         if structure is None:
             return None
-
+        data["length"] = len(structure)
         structure = [self.start_idx] + structure + [self.end_idx]  # add sos and eos
         structure = structure + [self.pad_idx] * (
             self._max_text_len - len(structure)
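
The hunk above records the structure length before SOS and EOS are added. A standalone sketch of that encoding step follows; `encode_structure` and its parameters are hypothetical stand-ins for the attributes used in `TableLabelEncode`, and `max_text_len` is assumed to be the total padded length including SOS and EOS.

```python
from typing import Dict, List, Union


def encode_structure(
    token_ids: List[int],
    start_idx: int,
    end_idx: int,
    pad_idx: int,
    max_text_len: int,
) -> Dict[str, Union[int, List[int]]]:
    """Wrap token ids with SOS/EOS, pad to max_text_len, and keep the raw length."""
    length = len(token_ids)  # recorded before SOS/EOS are added
    padded = [start_idx] + token_ids + [end_idx]
    padded = padded + [pad_idx] * (max_text_len - len(padded))
    return {"structure": padded, "length": length}


sample = encode_structure([5, 6, 7], start_idx=0, end_idx=1, pad_idx=9, max_text_len=8)
print(sample)
```

Because `length` excludes SOS and EOS, a consumer that wants the tokens plus EOS needs `length + 1` positions after the SOS.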
7 changes: 4 additions & 3 deletions ppocr/losses/table_att_loss.py
@@ -69,7 +69,8 @@ def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs):
     def forward(self, predicts, batch):
         structure_probs = predicts["structure_probs"]
         structure_targets = batch[1].astype("int64")
-        structure_targets = structure_targets[:, 1:]
+        max_len = batch[-2].max()
+        structure_targets = structure_targets[:, 1 : max_len + 2]

         structure_loss = self.loss_func(structure_probs, structure_targets)

@@ -78,8 +79,8 @@ def forward(self, predicts, batch):
         loc_preds = predicts["loc_preds"]
         loc_targets = batch[2].astype("float32")
         loc_targets_mask = batch[3].astype("float32")
-        loc_targets = loc_targets[:, 1:, :]
-        loc_targets_mask = loc_targets_mask[:, 1:, :]
+        loc_targets = loc_targets[:, 1 : max_len + 2]
+        loc_targets_mask = loc_targets_mask[:, 1 : max_len + 2]

         loc_loss = (
             F.smooth_l1_loss(
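
The slice `1 : max_len + 2` can be read as "skip SOS, then keep the longest structure in the batch plus its EOS". The sketch below shows this on plain Python lists, assuming `length` counts tokens before SOS/EOS are added, as in the `label_ops.py` change. `trim_targets` is a hypothetical helper, and the model outputs are assumed to cover the same `max_len + 1` steps.

```python
from typing import List


def trim_targets(batch_targets: List[List[int]], lengths: List[int]) -> List[List[int]]:
    """Drop SOS and cut each padded row to the longest structure in the batch plus EOS."""
    max_len = max(lengths)
    # Index 0 is SOS; positions 1..max_len+1 hold the tokens and EOS of the longest row.
    return [row[1 : max_len + 2] for row in batch_targets]


# Rows are [SOS=0, tokens..., EOS=1, PAD=9 ...] padded to a fixed length of 8.
batch = [
    [0, 5, 6, 7, 1, 9, 9, 9],
    [0, 5, 1, 9, 9, 9, 9, 9],
]
print(trim_targets(batch, lengths=[3, 1]))
```

Trimming to the batch maximum rather than the global `max_text_length` means the loss is no longer computed over long runs of padding shared by every row.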
2 changes: 2 additions & 0 deletions ppocr/modeling/backbones/__init__.py
@@ -25,6 +25,7 @@ def build_backbone(config, model_type):
         from .rec_lcnetv3 import PPLCNetV3
         from .rec_hgnet import PPHGNet_small
         from .rec_vit import ViT
+        from .det_pp_lcnet_v2 import PPLCNetV2_base
         from .rec_repvit import RepSVTR_det

         support_dict = [
@@ -35,6 +36,7 @@ def build_backbone(config, model_type):
             "PPLCNet",
             "PPLCNetV3",
             "PPHGNet_small",
+            "PPLCNetV2_base",
             "RepSVTR_det",
         ]
         if model_type == "table":
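
Both hunks are needed for `Backbone.name: PPLCNetV2_base` to resolve: the class has to be imported and its name has to be listed in `support_dict`. The sketch below illustrates that pattern with an explicit dictionary registry. `PPLCNetV2Stub`, `REGISTRY`, and `build_from_config` are hypothetical, and the real `build_backbone` may look names up differently.

```python
from typing import Any, Dict, List


class PPLCNetV2Stub:
    """Hypothetical stand-in for a backbone class."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


REGISTRY = {"PPLCNetV2_base": PPLCNetV2Stub}


def build_from_config(config: Dict[str, Any], support_dict: List[str]) -> Any:
    """Instantiate a backbone by name, rejecting names that are not whitelisted."""
    params = dict(config)  # copy so the caller's config is not mutated
    name = params.pop("name")
    if name not in support_dict:
        raise ValueError(f"backbone {name!r} is not in {support_dict}")
    return REGISTRY[name](**params)


backbone = build_from_config({"name": "PPLCNetV2_base"}, support_dict=["PPLCNetV2_base"])
print(type(backbone).__name__)
```

With only the import added, the whitelist check would still reject the new name; with only the list entry added, the name would pass the check but have no class behind it.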