diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml
index 3f2d6b28eab..4388c5f550c 100644
--- a/configs/table/SLANet.yml
+++ b/configs/table/SLANet.yml
@@ -101,7 +101,7 @@ Train:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
loader:
shuffle: True
batch_size_per_card: 48
@@ -137,7 +137,7 @@ Eval:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
loader:
shuffle: False
drop_last: False
diff --git a/configs/table/SLANet_ch.yml b/configs/table/SLANet_ch.yml
index 3b1e5c6bd9d..c16f7efed37 100644
--- a/configs/table/SLANet_ch.yml
+++ b/configs/table/SLANet_ch.yml
@@ -97,7 +97,7 @@ Train:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
loader:
shuffle: True
batch_size_per_card: 48
@@ -133,7 +133,7 @@ Eval:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
loader:
shuffle: False
drop_last: False
diff --git a/configs/table/SLANet_lcnetv2.yml b/configs/table/SLANet_lcnetv2.yml
new file mode 100644
index 00000000000..df11acf1455
--- /dev/null
+++ b/configs/table/SLANet_lcnetv2.yml
@@ -0,0 +1,139 @@
+Global:
+ use_gpu: true
+ epoch_num: 50
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/SLANet_lcnetv2
+ save_epoch_step: 20
+ # evaluation is run every 1000 iterations after the 0th iteration
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir: ./SLANet_lcnetv2_infer
+ use_visualdl: False
+ infer_img: ppstructure/docs/table/table.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+ character_type: en
+ max_text_length: &max_text_length 500
+ box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
+ infer_mode: False
+ use_sync_bn: True
+ save_res_path: 'output/infer'
+ d2s_train_image_shape: [3, -1, -1]
+ amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 5.0
+ lr:
+ learning_rate: 0.001
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: table
+ algorithm: SLANet
+ Backbone:
+ name: PPLCNetV2_base
+ Neck:
+ name: CSPPAN
+ out_channels: 96
+ Head:
+ name: SLAHead
+ hidden_size: 256
+ max_text_length: *max_text_length
+ loc_reg_num: &loc_reg_num 4
+
+Loss:
+ name: SLALoss
+ structure_weight: 1.0
+ loc_weight: 2.0
+ loc_loss: smooth_l1
+
+PostProcess:
+ name: TableLabelDecode
+ merge_no_span_structure: &merge_no_span_structure True
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+ compute_bbox_metric: False
+ loc_reg_num: *loc_reg_num
+ box_format: *box_format
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ../table_data/pubtabnet/train/
+ label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_train.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: *merge_no_span_structure
+ replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
+ - TableBoxEncode:
+ in_box_format: *box_format
+ out_box_format: *box_format
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+        keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
+ loader:
+ shuffle: True
+ batch_size_per_card: 24
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ../table_data/pubtabnet/val/
+ label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_val.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: *merge_no_span_structure
+ replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
+ - TableBoxEncode:
+ in_box_format: *box_format
+ out_box_format: *box_format
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+        keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 48
+ num_workers: 4
diff --git a/doc/doc_ch/algorithm_table_slanet.md b/doc/doc_ch/algorithm_table_slanet.md
new file mode 100644
index 00000000000..4aba89b415c
--- /dev/null
+++ b/doc/doc_ch/algorithm_table_slanet.md
@@ -0,0 +1,111 @@
+# Table Recognition Algorithm: SLANet-LCNetV2
+
+- [1. Introduction](#1-introduction)
+- [2. Environment Setup](#2-environment-setup)
+- [3. Training, Evaluation, and Prediction](#3-training-evaluation-and-prediction)
+- [4. Inference and Deployment](#4-inference-and-deployment)
+  - [4.1 Python Inference](#41-python-inference)
+  - [4.2 C++ Inference](#42-c-inference)
+  - [4.3 Serving Deployment](#43-serving-deployment)
+  - [4.4 More Deployment Options](#44-more-deployment-options)
+- [5. FAQ](#5-faq)
+
+
+## 1. Introduction
+
+First-place algorithm on the leaderboard of the PaddleOCR Algorithm Model Challenge, Track 2: General Table Recognition. Core ideas:
+
+1. Improve the inference loop to stop at EOS, for roughly a 3x speedup (see the sketch below)
+2. Upgrade the backbone to LCNetV2 (SSLD version)
+3. Add a row/column feature enhancement module
+4. Increase the input resolution from 488 to 512
+5. Use a three-stage training strategy
+
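+The first item is the main inference-time change: instead of always unrolling the decoder for `max_text_length` steps, decoding stops as soon as every sequence in the batch has emitted the EOS token. A minimal NumPy sketch of the idea, where `step_fn` is a hypothetical stand-in for one decoder step (the actual implementation is in `ppocr/modeling/heads/table_att_head.py`):
+
+```python
+import numpy as np
+
+def greedy_decode_with_eos(step_fn, batch_size, max_len, eos_id):
+    """Greedy decoding that stops once every sequence in the batch
+    contains an EOS token, instead of always running max_len steps."""
+    token_ids = np.zeros((batch_size, max_len), dtype=np.int64)
+    for i in range(max_len):
+        probs = step_fn(token_ids[:, :i])  # (batch, num_classes) scores for step i
+        token_ids[:, i] = probs.argmax(axis=1)
+        # early exit: every row now contains at least one EOS
+        if (token_ids[:, : i + 1] == eos_id).any(axis=1).all():
+            break
+    return token_ids[:, : i + 1]
+```
+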
+On the public PubTabNet table recognition dataset, the reproduced results are as follows:
+
+|Model|Backbone|Config|Acc|
+| --- | --- | --- | --- |
+|SLANet|LCNetV2|[configs/table/SLANet_lcnetv2.yml](../../configs/table/SLANet_lcnetv2.yml)|76.67%|
+
+
+
+## 2. Environment Setup
+Please refer to [Environment Preparation](./environment.md) to set up the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.
+
+
+
+## 3. Training, Evaluation, and Prediction
+
+The SLANet_LCNetV2 model above is trained on the public PubTabNet table recognition dataset; see [table_datasets](./dataset/table_datasets.md) for the download.
+
+### Start Training
+
+After the data is downloaded, follow the [text recognition tutorial](./recognition.md) for training. PaddleOCR is modularized; training a different model only requires **switching the config file**.
+
+The training commands are as follows:
+```shell
+# stage 1
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/table/SLANet_lcnetv2.yml
+# stage 2: load the best model of stage 1 as the pretrained model and lower the learning rate to 0.0001, e.g. (output path illustrative):
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/table/SLANet_lcnetv2.yml -o Global.pretrained_model=./output/SLANet_lcnetv2/best_accuracy Optimizer.lr.learning_rate=0.0001
+# stage 3: load the best model of stage 2 as the pretrained model, keep the learning rate unchanged, and change every 488 in the config file to 512.
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+Convert the best model obtained from training into an inference model with the following command:
+
+```shell
+# Note: set the pretrained_model path to a local path.
+python3 tools/export_model.py -c configs/table/SLANet_lcnetv2.yml -o Global.pretrained_model=path/best_accuracy Global.save_inference_dir=./inference/slanet_lcnetv2_infer
+```
+
+**Note:**
+- If the model was trained on your own dataset with a modified dictionary file, check that `character_dict_path` in the config file points to the correct dictionary.
+
+After a successful conversion, there are three files in the directory:
+```
+./inference/slanet_lcnetv2_infer/
+    ├── inference.pdiparams         # parameter file of the inference model
+    ├── inference.pdiparams.info    # parameter info of the inference model, can be ignored
+    └── inference.pdmodel           # program file of the inference model
+```
+
+
+Run the following command for model inference:
+
+```shell
+cd ppstructure/
+python3.7 table/predict_structure.py --table_model_dir=../inference/slanet_lcnetv2_infer/ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt --image_dir=docs/table/table.jpg --output=../output/table_slanet_lcnetv2 --use_gpu=False --benchmark=True --enable_mkldnn=True
+# To predict all images in a folder, set image_dir to the folder path, e.g. --image_dir='docs/table'.
+```
+
+After running the command, the prediction for the image above (the structure tokens and the coordinates of each table cell) is printed to the screen, and a visualization of the cell boxes is saved. Example output:
+```shell
+[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '<td>', '</td>', '</tr>', ..., '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
+...
+[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
+[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
+[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
+```
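+
+The predicted structure tokens are plain HTML table tags, so they can be joined directly into a renderable table skeleton (the cells are empty because this model predicts only structure and cell boxes, not cell text). An illustrative Python snippet, with the token list abbreviated:
+
+```python
+# structure tokens as printed by predict_structure.py (abbreviated)
+structure_tokens = ['<html>', '<body>', '<table>', '<tr>', '<td>', '</td>', '</tr>',
+                    '</table>', '</body>', '</html>']
+html = ''.join(structure_tokens)  # '<html><body><table><tr><td></td></tr>...'
+with open('pred_table.html', 'w') as f:
+    f.write(html)
+```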
+
+
+### 4.2 C++ Inference
+
+Not supported yet, because the C++ pre- and post-processing do not yet support SLANet.
+
+
+### 4.3 Serving Deployment
+
+Not supported yet.
+
+
+### 4.4 More Deployment Options
+
+Not supported yet.
+
+
+## 5. FAQ
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index ffb753f7994..39e413b2d23 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -708,7 +708,7 @@ def __call__(self, data):
structure = self.encode(new_structure)
if structure is None:
return None
-
+ data["length"] = len(structure)
        structure = [self.start_idx] + structure + [self.end_idx]  # add sos and eos
structure = structure + [self.pad_idx] * (
self._max_text_len - len(structure)
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index 51160a25db9..5f0e7806eed 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -69,7 +69,8 @@ def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs):
def forward(self, predicts, batch):
structure_probs = predicts["structure_probs"]
structure_targets = batch[1].astype("int64")
- structure_targets = structure_targets[:, 1:]
+        max_len = batch[-2].max()  # longest un-padded structure length in the batch ('length' from KeepKeys)
+        structure_targets = structure_targets[:, 1 : max_len + 2]  # drop SOS, keep tokens plus EOS
structure_loss = self.loss_func(structure_probs, structure_targets)
@@ -78,8 +79,8 @@ def forward(self, predicts, batch):
loc_preds = predicts["loc_preds"]
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[3].astype("float32")
- loc_targets = loc_targets[:, 1:, :]
- loc_targets_mask = loc_targets_mask[:, 1:, :]
+ loc_targets = loc_targets[:, 1 : max_len + 2]
+ loc_targets_mask = loc_targets_mask[:, 1 : max_len + 2]
loc_loss = (
F.smooth_l1_loss(
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 0b64992b208..81d107c293c 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -25,6 +25,7 @@ def build_backbone(config, model_type):
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit import ViT
+ from .det_pp_lcnet_v2 import PPLCNetV2_base
from .rec_repvit import RepSVTR_det
support_dict = [
@@ -35,6 +36,7 @@ def build_backbone(config, model_type):
"PPLCNet",
"PPLCNetV3",
"PPHGNet_small",
+ "PPLCNetV2_base",
"RepSVTR_det",
]
if model_type == "table":
diff --git a/ppocr/modeling/backbones/det_pp_lcnet_v2.py b/ppocr/modeling/backbones/det_pp_lcnet_v2.py
new file mode 100644
index 00000000000..5b5a568a254
--- /dev/null
+++ b/ppocr/modeling/backbones/det_pp_lcnet_v2.py
@@ -0,0 +1,358 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+from paddle.utils.download import get_path_from_url
+
+MODEL_URLS = {
+ "PPLCNetV2_small": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_small_ssld_pretrained.pdparams",
+ "PPLCNetV2_base": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_ssld_pretrained.pdparams",
+ "PPLCNetV2_large": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_large_ssld_pretrained.pdparams",
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+NET_CONFIG = {
+ # in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut
+ "stage1": [64, 3, False, False, False, False],
+ "stage2": [128, 3, False, False, False, False],
+ "stage3": [256, 5, True, True, True, False],
+ "stage4": [512, 5, False, True, False, True],
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    # Round a channel count to the nearest multiple of `divisor`,
+    # never dropping more than 10% below the original value.
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
+ ):
+ super().__init__()
+ self.use_act = use_act
+ self.conv = Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False,
+ )
+
+ self.bn = BatchNorm2D(
+ out_channels,
+ weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ )
+ if self.use_act:
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.use_act:
+ x = self.act(x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channel, reduction=4):
+ super().__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.relu = nn.ReLU()
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+        self.hardsigmoid = nn.Sigmoid()  # plain Sigmoid despite the name, as in the PaddleClas PPLCNetV2 code
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.hardsigmoid(x)
+ x = paddle.multiply(x=identity, y=x)
+ return x
+
+
+class RepDepthwiseSeparable(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ dw_size=3,
+ split_pw=False,
+ use_rep=False,
+ use_se=False,
+ use_shortcut=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.is_repped = False
+
+ self.dw_size = dw_size
+ self.split_pw = split_pw
+ self.use_rep = use_rep
+ self.use_se = use_se
+ self.use_shortcut = (
+ True
+ if use_shortcut and stride == 1 and in_channels == out_channels
+ else False
+ )
+
+ if self.use_rep:
+ self.dw_conv_list = nn.LayerList()
+ for kernel_size in range(self.dw_size, 0, -2):
+ if kernel_size == 1 and stride != 1:
+ continue
+ dw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=in_channels,
+ use_act=False,
+ )
+ self.dw_conv_list.append(dw_conv)
+ self.dw_conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=dw_size,
+ stride=stride,
+ padding=(dw_size - 1) // 2,
+ groups=in_channels,
+ )
+ else:
+ self.dw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=dw_size,
+ stride=stride,
+ groups=in_channels,
+ )
+
+ self.act = nn.ReLU()
+
+ if use_se:
+ self.se = SEModule(in_channels)
+
+ if self.split_pw:
+ pw_ratio = 0.5
+ self.pw_conv_1 = ConvBNLayer(
+ in_channels=in_channels,
+ kernel_size=1,
+ out_channels=int(out_channels * pw_ratio),
+ stride=1,
+ )
+ self.pw_conv_2 = ConvBNLayer(
+ in_channels=int(out_channels * pw_ratio),
+ kernel_size=1,
+ out_channels=out_channels,
+ stride=1,
+ )
+ else:
+ self.pw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ kernel_size=1,
+ out_channels=out_channels,
+ stride=1,
+ )
+
+ def forward(self, x):
+ if self.use_rep:
+ input_x = x
+ if self.is_repped:
+ x = self.act(self.dw_conv(x))
+ else:
+ y = self.dw_conv_list[0](x)
+ for dw_conv in self.dw_conv_list[1:]:
+ y += dw_conv(x)
+ x = self.act(y)
+ else:
+ x = self.dw_conv(x)
+
+ if self.use_se:
+ x = self.se(x)
+ if self.split_pw:
+ x = self.pw_conv_1(x)
+ x = self.pw_conv_2(x)
+ else:
+ x = self.pw_conv(x)
+ if self.use_shortcut:
+ x = x + input_x
+ return x
+
+ def re_parameterize(self):
+ if self.use_rep:
+ self.is_repped = True
+ kernel, bias = self._get_equivalent_kernel_bias()
+ self.dw_conv.weight.set_value(kernel)
+ self.dw_conv.bias.set_value(bias)
+
+ def _get_equivalent_kernel_bias(self):
+ kernel_sum = 0
+ bias_sum = 0
+ for dw_conv in self.dw_conv_list:
+ kernel, bias = self._fuse_bn_tensor(dw_conv)
+ kernel = self._pad_tensor(kernel, to_size=self.dw_size)
+ kernel_sum += kernel
+ bias_sum += bias
+ return kernel_sum, bias_sum
+
+ def _fuse_bn_tensor(self, branch):
+ kernel = branch.conv.weight
+ running_mean = branch.bn._mean
+ running_var = branch.bn._variance
+ gamma = branch.bn.weight
+ beta = branch.bn.bias
+ eps = branch.bn._epsilon
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape((-1, 1, 1, 1))
+ return kernel * t, beta - running_mean * gamma / std
+
+ def _pad_tensor(self, tensor, to_size):
+ from_size = tensor.shape[-1]
+ if from_size == to_size:
+ return tensor
+ pad = (to_size - from_size) // 2
+ return F.pad(tensor, [pad, pad, pad, pad])
+
+
+class PPLCNetV2(nn.Layer):
+ def __init__(self, scale, depths, out_indx=[1, 2, 3, 4], **kwargs):
+ super().__init__(**kwargs)
+ self.scale = scale
+ self.out_channels = [
+ int(NET_CONFIG["stage1"][0] * scale * 2),
+ int(NET_CONFIG["stage2"][0] * scale * 2),
+ int(NET_CONFIG["stage3"][0] * scale * 2),
+ int(NET_CONFIG["stage4"][0] * scale * 2),
+ ]
+ self.stem = nn.Sequential(
+ *[
+ ConvBNLayer(
+ in_channels=3,
+ kernel_size=3,
+ out_channels=make_divisible(32 * scale),
+ stride=2,
+ ),
+ RepDepthwiseSeparable(
+ in_channels=make_divisible(32 * scale),
+ out_channels=make_divisible(64 * scale),
+ stride=1,
+ dw_size=3,
+ ),
+ ]
+ )
+ self.out_indx = out_indx
+ # stages
+ self.stages = nn.LayerList()
+ for depth_idx, k in enumerate(NET_CONFIG):
+ (
+ in_channels,
+ kernel_size,
+ split_pw,
+ use_rep,
+ use_se,
+ use_shortcut,
+ ) = NET_CONFIG[k]
+ self.stages.append(
+ nn.Sequential(
+ *[
+ RepDepthwiseSeparable(
+ in_channels=make_divisible(
+ (in_channels if i == 0 else in_channels * 2) * scale
+ ),
+ out_channels=make_divisible(in_channels * 2 * scale),
+ stride=2 if i == 0 else 1,
+ dw_size=kernel_size,
+ split_pw=split_pw,
+ use_rep=use_rep,
+ use_se=use_se,
+ use_shortcut=use_shortcut,
+ )
+ for i in range(depths[depth_idx])
+ ]
+ )
+ )
+
+        # always initialize from the SSLD-pretrained ImageNet weights
+        self._load_pretrained(MODEL_URLS["PPLCNetV2_base"], use_ssld=True)
+
+ def forward(self, x):
+ x = self.stem(x)
+ i = 1
+ outs = []
+ for stage in self.stages:
+ x = stage(x)
+ if i in self.out_indx:
+ outs.append(x)
+ i += 1
+ return outs
+
+    def _load_pretrained(self, pretrained_url, use_ssld=False):
+        local_weight_path = get_path_from_url(
+            pretrained_url, os.path.expanduser("~/.paddleclas/weights")
+        )
+        param_state_dict = paddle.load(local_weight_path)
+        self.set_dict(param_state_dict)
+        print("load pretrained ssld weights success!")
+
+
+def PPLCNetV2_base(in_channels=3, **kwargs):
+ """
+ PPLCNetV2_base
+ Args:
+ pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+ If str, means the path of the pretrained model.
+ use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+ Returns:
+ model: nn.Layer. Specific `PPLCNetV2_base` model depends on args.
+ """
+ model = PPLCNetV2(scale=1.0, depths=[2, 2, 6, 2], **kwargs)
+ return model
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index d204d5a433a..7e2dc219638 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -24,6 +24,7 @@
import numpy as np
from .rec_att_head import AttentionGRUCell
+from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, Mlp
def get_para_bias_attr(l2_decay, k):
@@ -134,6 +135,120 @@ def forward(self, inputs, targets=None):
return {"structure_probs": structure_probs, "loc_preds": loc_preds}
+class HWAttention(nn.Layer):
+ def __init__(
+ self,
+ head_dim=32,
+ qk_scale=None,
+ attn_drop=0.0,
+ ):
+ super().__init__()
+ self.head_dim = head_dim
+ self.scale = qk_scale or self.head_dim**-0.5
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ C = C // 3
+ qkv = x.reshape([B, N, 3, C // self.head_dim, self.head_dim]).transpose(
+ [2, 0, 3, 1, 4]
+ )
+ q, k, v = qkv.unbind(0)
+ attn = q @ k.transpose([0, 1, 3, 2]) * self.scale
+ attn = F.softmax(attn, -1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+        x = x.transpose([0, 2, 1, 3]).reshape([B, N, C])
+ return x
+
+
+def img2windows(img, H_sp, W_sp):
+ """
+    img: B H W C
+ """
+ B, H, W, C = img.shape
+ img_reshape = img.reshape([B, H // H_sp, H_sp, W // W_sp, W_sp, C])
+ img_perm = img_reshape.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H_sp * W_sp, C])
+ return img_perm
+
+
+def windows2img(img_splits_hw, H_sp, W_sp, H, W):
+ """
+    img_splits_hw: (B*num_windows, H_sp*W_sp, C)
+ """
+ B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
+
+ img = img_splits_hw.reshape([B, H // H_sp, W // W_sp, H_sp, W_sp, -1])
+ img = img.transpose([0, 1, 3, 2, 4, 5]).flatten(1, 4)
+ return img
+
+
+class Block(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ split_h=4,
+ split_w=4,
+ h_num_heads=None,
+ w_num_heads=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ ):
+ super().__init__()
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+ self.split_h = split_h
+ self.split_w = split_w
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, epsilon=eps)
+ self.h_num_heads = h_num_heads if h_num_heads is not None else num_heads // 2
+ self.w_num_heads = w_num_heads if w_num_heads is not None else num_heads // 2
+ self.head_dim = dim // num_heads
+ self.mixer = HWAttention(
+ head_dim=dim // num_heads,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, epsilon=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = x.flatten(2).transpose([0, 2, 1])
+
+ qkv = self.qkv(x).reshape([B, H, W, 3 * C])
+
+ x1 = qkv[:, :, :, : 3 * self.h_num_heads * self.head_dim] # b, h, w, 3ch
+ x2 = qkv[:, :, :, 3 * self.h_num_heads * self.head_dim :] # b, h, w, 3cw
+
+        x1 = self.mixer(img2windows(x1, self.split_h, W))  # (B*H/split_h, split_h*W, ch)
+        x2 = self.mixer(img2windows(x2, H, self.split_w))  # (B*W/split_w, H*split_w, cw)
+        x1 = windows2img(x1, self.split_h, W, H, W)
+        x2 = windows2img(x2, H, self.split_w, H, W)
+
+        attended_x = paddle.concat([x1, x2], 2)
+        attended_x = self.proj(attended_x)
+
+        x = self.norm1(x + self.drop_path(attended_x))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ x = x.transpose([0, 2, 1]).reshape([-1, C, H, W])
+ return x
+
+
class SLAHead(nn.Layer):
def __init__(
self,
@@ -143,6 +258,7 @@ def __init__(
max_text_length=500,
loc_reg_num=4,
fc_decay=0.0,
+ use_attn=False,
**kwargs
):
"""
@@ -158,6 +274,7 @@ def __init__(
self.emb = self._char_to_onehot
self.num_embeddings = out_channels
self.loc_reg_num = loc_reg_num
+        self.eos = self.num_embeddings - 1  # EOS id, the last index of the structure vocabulary
# structure
self.structure_attention_cell = AttentionGRUCell(
@@ -181,6 +298,21 @@ def __init__(
hidden_size, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
),
)
+ dpr = np.linspace(0, 0.1, 2)
+
+ self.use_attn = use_attn
+ if use_attn:
+ layer_list = [
+ Block(
+ in_channels,
+ num_heads=2,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ )
+ for i in range(2)
+ ]
+ self.cross_atten = nn.Sequential(*layer_list)
# loc
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=self.hidden_size
@@ -207,6 +339,8 @@ def __init__(
def forward(self, inputs, targets=None):
fea = inputs[-1]
batch_size = fea.shape[0]
+ if self.use_attn:
+ fea = fea + self.cross_atten(fea)
# reshape
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
@@ -220,15 +354,22 @@ def forward(self, inputs, targets=None):
)
structure_preds.stop_gradient = True
loc_preds.stop_gradient = True
+
if self.training and targets is not None:
structure = targets[0]
- for i in range(self.max_text_length + 1):
+            max_len = targets[-2].max()  # longest un-padded structure in the batch
+ for i in range(max_len + 1):
hidden, structure_step, loc_step = self._decode(
structure[:, i], fea, hidden
)
structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step
+            structure_preds = structure_preds[:, : max_len + 1]  # drop the never-decoded padded steps
+ loc_preds = loc_preds[:, : max_len + 1]
else:
+            # per-step argmax ids, used to detect when every sequence has emitted EOS
+            structure_ids = paddle.zeros(
+                (batch_size, self.max_text_length + 1), dtype=paddle.int64
+            )
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
max_text_length = paddle.to_tensor(self.max_text_length)
# for export
@@ -238,8 +379,13 @@ def forward(self, inputs, targets=None):
pre_chars = structure_step.argmax(axis=1, dtype="int32")
structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step
+
+            structure_ids[:, i] = pre_chars
+            # stop decoding once every sequence in the batch has emitted EOS
+            if (structure_ids == self.eos).any(-1).all():
+                break
if not self.training:
- structure_preds = F.softmax(structure_preds)
+ structure_preds = F.softmax(structure_preds[:, : i + 1])
+ loc_preds = loc_preds[:, : i + 1]
return {"structure_probs": structure_preds, "loc_preds": loc_preds}
def _decode(self, pre_chars, features, hidden):