Skip to content

Commit 8b71785

Browse files
authored
table rec code (#11999)
* table rec code * 'fixtableinit' * copyright 2024 * table rec pre-commit * table rec slanet_lcnetv2 doc * table rec slanet_lcnetv2 doc * hwattention fix * tablelabelencode add length item
1 parent 38c0c9e commit 8b71785

File tree

9 files changed

+767
-10
lines changed

9 files changed

+767
-10
lines changed

configs/table/SLANet.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ Train:
101101
size: [488, 488]
102102
- ToCHWImage:
103103
- KeepKeys:
104-
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
104+
keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
105105
loader:
106106
shuffle: True
107107
batch_size_per_card: 48
@@ -137,7 +137,7 @@ Eval:
137137
size: [488, 488]
138138
- ToCHWImage:
139139
- KeepKeys:
140-
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
140+
keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
141141
loader:
142142
shuffle: False
143143
drop_last: False

configs/table/SLANet_ch.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Train:
9797
size: [488, 488]
9898
- ToCHWImage:
9999
- KeepKeys:
100-
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
100+
keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
101101
loader:
102102
shuffle: True
103103
batch_size_per_card: 48
@@ -133,7 +133,7 @@ Eval:
133133
size: [488, 488]
134134
- ToCHWImage:
135135
- KeepKeys:
136-
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
136+
keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
137137
loader:
138138
shuffle: False
139139
drop_last: False

configs/table/SLANet_lcnetv2.yml

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
Global:
2+
use_gpu: true
3+
epoch_num: 50
4+
log_smooth_window: 20
5+
print_batch_step: 10
6+
save_model_dir: ./output/SLANet_lcnetv2
7+
save_epoch_step: 20
8+
# evaluation is run every 1000 iterations after the 0th iteration
9+
eval_batch_step: [0, 1000]
10+
cal_metric_during_train: True
11+
pretrained_model:
12+
checkpoints:
13+
save_inference_dir: ./SLANet_lcnetv2_infer
14+
use_visualdl: False
15+
infer_img: ppstructure/docs/table/table.jpg
16+
# for data or label process
17+
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
18+
character_type: en
19+
max_text_length: &max_text_length 500
20+
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
21+
infer_mode: False
22+
use_sync_bn: True
23+
save_res_path: 'output/infer'
24+
d2s_train_image_shape: [3, -1, -1]
25+
amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']
26+
27+
Optimizer:
28+
name: Adam
29+
beta1: 0.9
30+
beta2: 0.999
31+
clip_norm: 5.0
32+
lr:
33+
learning_rate: 0.001
34+
regularizer:
35+
name: 'L2'
36+
factor: 0.00000
37+
38+
Architecture:
39+
model_type: table
40+
algorithm: SLANet
41+
Backbone:
42+
name: PPLCNetV2_base
43+
Neck:
44+
name: CSPPAN
45+
out_channels: 96
46+
Head:
47+
name: SLAHead
48+
hidden_size: 256
49+
max_text_length: *max_text_length
50+
loc_reg_num: &loc_reg_num 4
51+
52+
Loss:
53+
name: SLALoss
54+
structure_weight: 1.0
55+
loc_weight: 2.0
56+
loc_loss: smooth_l1
57+
58+
PostProcess:
59+
name: TableLabelDecode
60+
merge_no_span_structure: &merge_no_span_structure True
61+
62+
Metric:
63+
name: TableMetric
64+
main_indicator: acc
65+
compute_bbox_metric: False
66+
loc_reg_num: *loc_reg_num
67+
box_format: *box_format
68+
69+
Train:
70+
dataset:
71+
name: PubTabDataSet
72+
data_dir: ../table_data/pubtabnet/train/
73+
label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_train.jsonl]
74+
transforms:
75+
- DecodeImage: # load image
76+
img_mode: BGR
77+
channel_first: False
78+
- TableLabelEncode:
79+
learn_empty_box: False
80+
merge_no_span_structure: *merge_no_span_structure
81+
replace_empty_cell_token: False
82+
loc_reg_num: *loc_reg_num
83+
max_text_length: *max_text_length
84+
- TableBoxEncode:
85+
in_box_format: *box_format
86+
out_box_format: *box_format
87+
- ResizeTableImage:
88+
max_len: 488
89+
- NormalizeImage:
90+
scale: 1./255.
91+
mean: [0.485, 0.456, 0.406]
92+
std: [0.229, 0.224, 0.225]
93+
order: 'hwc'
94+
- PaddingTableImage:
95+
size: [488, 488]
96+
- ToCHWImage:
97+
- KeepKeys:
98+
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape' ]
99+
loader:
100+
shuffle: True
101+
batch_size_per_card: 24
102+
drop_last: True
103+
num_workers: 8
104+
105+
Eval:
106+
dataset:
107+
name: PubTabDataSet
108+
data_dir: ../table_data/pubtabnet/val/
109+
label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_val.jsonl]
110+
transforms:
111+
- DecodeImage: # load image
112+
img_mode: BGR
113+
channel_first: False
114+
- TableLabelEncode:
115+
learn_empty_box: False
116+
merge_no_span_structure: *merge_no_span_structure
117+
replace_empty_cell_token: False
118+
loc_reg_num: *loc_reg_num
119+
max_text_length: *max_text_length
120+
- TableBoxEncode:
121+
in_box_format: *box_format
122+
out_box_format: *box_format
123+
- ResizeTableImage:
124+
max_len: 488
125+
- NormalizeImage:
126+
scale: 1./255.
127+
mean: [0.485, 0.456, 0.406]
128+
std: [0.229, 0.224, 0.225]
129+
order: 'hwc'
130+
- PaddingTableImage:
131+
size: [488, 488]
132+
- ToCHWImage:
133+
- KeepKeys:
134+
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape' ]
135+
loader:
136+
shuffle: False
137+
drop_last: False
138+
batch_size_per_card: 48
139+
num_workers: 4

doc/doc_ch/algorithm_table_slanet.md

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# 表格识别算法-SLANet-LCNetV2
2+
3+
- [1. 算法简介](#1-算法简介)
4+
- [2. 环境配置](#2-环境配置)
5+
- [3. 模型训练、评估、预测](#3-模型训练评估预测)
6+
- [4. 推理部署](#4-推理部署)
7+
- [4.1 Python推理](#41-python推理)
8+
- [4.2 C++推理部署](#42-c推理部署)
9+
- [4.3 Serving服务化部署](#43-serving服务化部署)
10+
- [4.4 更多推理部署](#44-更多推理部署)
11+
- [5. FAQ](#5-faq)
12+
13+
<a name="1"></a>
14+
## 1. 算法简介
15+
16+
PaddleOCR 算法模型挑战赛 - 赛题二:通用表格识别任务排行榜第一算法。核心思路:
17+
18+
- 1. 改善推理过程,至EOS停止,速度提升3倍
19+
- 2. 升级Backbone为LCNetV2(SSLD版本)
20+
- 3. 行列特征增强模块
21+
- 4. 提升分辨率488至512
22+
- 5. 三阶段训练策略
23+
24+
在PubTabNet表格识别公开数据集上,算法复现效果如下:
25+
26+
|模型|骨干网络|配置文件|acc|
27+
| --- | --- | --- | --- |
28+
|SLANet|LCNetV2|[configs/table/SLANet_lcnetv2.yml](../../configs/table/SLANet_lcnetv2.yml)|76.67%|
29+
30+
31+
<a name="2"></a>
32+
## 2. 环境配置
33+
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
34+
35+
36+
<a name="3"></a>
37+
## 3. 模型训练、评估、预测
38+
39+
上述SLANet_LCNetv2模型使用PubTabNet表格识别公开数据集训练得到,数据集下载可参考 [table_datasets](./dataset/table_datasets.md)。
40+
41+
### 启动训练
42+
43+
数据下载完成后,请参考[文本识别教程](./recognition.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的模型只需要**更换配置文件**即可。
44+
45+
训练命令如下:
46+
```shell
47+
# stage1
48+
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/table/SLANet_lcnetv2.yml
49+
# stage2 加载stage1的best model作为预训练模型,学习率调整为0.0001;
50+
# stage3 加载stage2的best model作为预训练模型,不调整学习率,将配置文件中所有的488修改为512.
51+
```
52+
53+
<a name="4"></a>
54+
## 4. 推理部署
55+
56+
<a name="4-1"></a>
57+
### 4.1 Python推理
58+
将训练得到的best模型转换成inference model,可以使用如下命令进行转换:
59+
60+
```shell
61+
# 注意将pretrained_model的路径设置为本地路径。
62+
python3 tools/export_model.py -c configs/table/SLANet_lcnetv2.yml -o Global.pretrained_model=path/best_accuracy Global.save_inference_dir=./inference/slanet_lcnetv2_infer
63+
```
64+
65+
**注意:**
66+
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意检查配置文件中的`character_dict_path`是否为正确的字典文件。
67+
68+
转换成功后,在目录下有三个文件:
69+
```
70+
./inference/slanet_lcnetv2_infer/
71+
├── inference.pdiparams # 识别inference模型的参数文件
72+
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
73+
└── inference.pdmodel # 识别inference模型的program文件
74+
```
75+
76+
77+
执行如下命令进行模型推理:
78+
79+
```shell
80+
cd ppstructure/
81+
python3 table/predict_structure.py --table_model_dir=../inference/slanet_lcnetv2_infer/ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --image_dir=docs/table/table.jpg --output=../output/table_slanet_lcnetv2 --use_gpu=False --benchmark=True --enable_mkldnn=True
82+
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='docs/table'。
83+
```
84+
85+
执行命令后,上面图像的预测结果(结构信息和表格中每个单元格的坐标)会打印到屏幕上,同时会保存单元格坐标的可视化结果。
86+
结果如下:
87+
```shell
88+
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
89+
...
90+
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
91+
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
92+
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
93+
```
94+
95+
<a name="4-2"></a>
96+
### 4.2 C++推理部署
97+
98+
由于C++预处理后处理还未支持SLANet,所以暂不支持C++推理部署。
99+
100+
<a name="4-3"></a>
101+
### 4.3 Serving服务化部署
102+
103+
暂不支持
104+
105+
<a name="4-4"></a>
106+
### 4.4 更多推理部署
107+
108+
暂不支持
109+
110+
<a name="5"></a>
111+
## 5. FAQ

ppocr/data/imaug/label_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def __call__(self, data):
708708
structure = self.encode(new_structure)
709709
if structure is None:
710710
return None
711-
711+
data["length"] = len(structure)
712712
structure = [self.start_idx] + structure + [self.end_idx] # add sos and eos
713713
structure = structure + [self.pad_idx] * (
714714
self._max_text_len - len(structure)

ppocr/losses/table_att_loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs):
6969
def forward(self, predicts, batch):
7070
structure_probs = predicts["structure_probs"]
7171
structure_targets = batch[1].astype("int64")
72-
structure_targets = structure_targets[:, 1:]
72+
max_len = batch[-2].max()
73+
structure_targets = structure_targets[:, 1 : max_len + 2]
7374

7475
structure_loss = self.loss_func(structure_probs, structure_targets)
7576

@@ -78,8 +79,8 @@ def forward(self, predicts, batch):
7879
loc_preds = predicts["loc_preds"]
7980
loc_targets = batch[2].astype("float32")
8081
loc_targets_mask = batch[3].astype("float32")
81-
loc_targets = loc_targets[:, 1:, :]
82-
loc_targets_mask = loc_targets_mask[:, 1:, :]
82+
loc_targets = loc_targets[:, 1 : max_len + 2]
83+
loc_targets_mask = loc_targets_mask[:, 1 : max_len + 2]
8384

8485
loc_loss = (
8586
F.smooth_l1_loss(

ppocr/modeling/backbones/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def build_backbone(config, model_type):
2525
from .rec_lcnetv3 import PPLCNetV3
2626
from .rec_hgnet import PPHGNet_small
2727
from .rec_vit import ViT
28+
from .det_pp_lcnet_v2 import PPLCNetV2_base
2829
from .rec_repvit import RepSVTR_det
2930

3031
support_dict = [
@@ -35,6 +36,7 @@ def build_backbone(config, model_type):
3536
"PPLCNet",
3637
"PPLCNetV3",
3738
"PPHGNet_small",
39+
"PPLCNetV2_base",
3840
"RepSVTR_det",
3941
]
4042
if model_type == "table":

0 commit comments

Comments
 (0)