diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml
index 3f2d6b28eab..4388c5f550c 100644
--- a/configs/table/SLANet.yml
+++ b/configs/table/SLANet.yml
@@ -101,7 +101,7 @@ Train:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: True
     batch_size_per_card: 48
@@ -137,7 +137,7 @@ Eval:
           size: [488, 488]
       - ToCHWImage:
      - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: False
     drop_last: False
diff --git a/configs/table/SLANet_ch.yml b/configs/table/SLANet_ch.yml
index 3b1e5c6bd9d..c16f7efed37 100644
--- a/configs/table/SLANet_ch.yml
+++ b/configs/table/SLANet_ch.yml
@@ -97,7 +97,7 @@ Train:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: True
     batch_size_per_card: 48
@@ -133,7 +133,7 @@ Eval:
           size: [488, 488]
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+          keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
   loader:
     shuffle: False
     drop_last: False
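Editor's note on why `'length'` is added to `keep_keys`: the loss and head changes below slice padded targets to the longest structure sequence in the batch, which needs each sample's true token count. A minimal sketch of the convention (dict values are placeholders; this assumes the usual PaddleOCR `KeepKeys` behavior of emitting the kept values in list order, so `length` lands at `batch[-2]` and `shape` at `batch[-1]`):

```python
# Sketch of the KeepKeys ordering convention, with dummy stand-in values.
data = {
    "image": "CHW array",
    "structure": "padded token ids",
    "bboxes": "padded cell boxes",
    "bbox_masks": "box validity mask",
    "length": 12,          # set by TableLabelEncode: len(structure) before padding
    "shape": "resize info",
}
keep_keys = ["image", "structure", "bboxes", "bbox_masks", "length", "shape"]
batch = [data[k] for k in keep_keys]
assert batch[-2] == data["length"]  # what SLALoss reads as batch[-2]
```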
diff --git a/configs/table/SLANet_lcnetv2.yml b/configs/table/SLANet_lcnetv2.yml
new file mode 100644
index 00000000000..df11acf1455
--- /dev/null
+++ b/configs/table/SLANet_lcnetv2.yml
@@ -0,0 +1,139 @@
+Global:
+  use_gpu: true
+  epoch_num: 50
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/SLANet_lcnetv2
+  save_epoch_step: 20
+  # evaluation is run every 1000 iterations after the 0th iteration
+  eval_batch_step: [0, 1000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir: ./SLANet_lcnetv2_infer
+  use_visualdl: False
+  infer_img: ppstructure/docs/table/table.jpg
+  # for data or label process
+  character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+  character_type: en
+  max_text_length: &max_text_length 500
+  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
+  infer_mode: False
+  use_sync_bn: True
+  save_res_path: 'output/infer'
+  d2s_train_image_shape: [3, -1, -1]
+  amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 5.0
+  lr:
+    learning_rate: 0.001
+  regularizer:
+    name: 'L2'
+    factor: 0.00000
+
+Architecture:
+  model_type: table
+  algorithm: SLANet
+  Backbone:
+    name: PPLCNetV2_base
+  Neck:
+    name: CSPPAN
+    out_channels: 96
+  Head:
+    name: SLAHead
+    hidden_size: 256
+    max_text_length: *max_text_length
+    loc_reg_num: &loc_reg_num 4
+
+Loss:
+  name: SLALoss
+  structure_weight: 1.0
+  loc_weight: 2.0
+  loc_loss: smooth_l1
+
+PostProcess:
+  name: TableLabelDecode
+  merge_no_span_structure: &merge_no_span_structure True
+
+Metric:
+  name: TableMetric
+  main_indicator: acc
+  compute_bbox_metric: False
+  loc_reg_num: *loc_reg_num
+  box_format: *box_format
+
+Train:
+  dataset:
+    name: PubTabDataSet
+    data_dir: ../table_data/pubtabnet/train/
+    label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_train.jsonl]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - TableLabelEncode:
+          learn_empty_box: False
+          merge_no_span_structure: *merge_no_span_structure
+          replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
+      - TableBoxEncode:
+          in_box_format: *box_format
+          out_box_format: *box_format
+      - ResizeTableImage:
+          max_len: 488
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - PaddingTableImage:
+          size: [488, 488]
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape' ]
+  loader:
+    shuffle: True
+    batch_size_per_card: 24
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: PubTabDataSet
+    data_dir: ../table_data/pubtabnet/val/
+    label_file_list: [../table_data/pubtabnet/PubTabNet_2.0.0_val.jsonl]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - TableLabelEncode:
+          learn_empty_box: False
+          merge_no_span_structure: *merge_no_span_structure
+          replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
+      - TableBoxEncode:
+          in_box_format: *box_format
+          out_box_format: *box_format
+      - ResizeTableImage:
+          max_len: 488
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - PaddingTableImage:
+          size: [488, 488]
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape' ]
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 48
+    num_workers: 4
diff --git a/doc/doc_ch/algorithm_table_slanet.md b/doc/doc_ch/algorithm_table_slanet.md
new file mode 100644
index 00000000000..4aba89b415c
--- /dev/null
+++ b/doc/doc_ch/algorithm_table_slanet.md
@@ -0,0 +1,111 @@
+# Table Recognition Algorithm: SLANet-LCNetV2
+
+- [1. Introduction](#1-introduction)
+- [2. Environment Setup](#2-environment-setup)
+- [3. Training, Evaluation, and Prediction](#3-training-evaluation-and-prediction)
+- [4. Inference and Deployment](#4-inference-and-deployment)
+  - [4.1 Python Inference](#41-python-inference)
+  - [4.2 C++ Inference](#42-c-inference)
+  - [4.3 Serving Deployment](#43-serving-deployment)
+  - [4.4 More Deployment Options](#44-more-deployment-options)
+- [5. FAQ](#5-faq)
+
+## 1. Introduction
+
+First-place algorithm on the leaderboard of Task 2 (general table recognition) of the PaddleOCR Algorithm Model Challenge. Core ideas:
+
+1. Improve the inference procedure to stop decoding at EOS, a roughly 3x speedup (see the sketch after the table below)
+2. Upgrade the backbone to LCNetV2 (SSLD version)
+3. Add a row/column feature-enhancement module
+4. Increase the input resolution from 488 to 512
+5. Use a three-stage training strategy
+
+Reproduced results on the public PubTabNet table recognition dataset:
+
+| Model | Backbone | Config | Acc |
+| --- | --- | --- | --- |
+| SLANet | LCNetV2 | [configs/table/SLANet_lcnetv2.yml](../../configs/table/SLANet_lcnetv2.yml) | 76.67% |
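+
+Improvement 1 can be sketched as follows. This is a simplified, editor-provided illustration of the early-exit decoding added to `SLAHead` in this PR, not the literal implementation: decoding proceeds step by step and breaks out of the loop once every sequence in the batch has emitted the EOS token, instead of always running `max_text_length` steps. The `step` function and the EOS id 38 are illustrative stand-ins.
+
+```python
+import paddle
+
+def step(i, batch_size=2):
+    # Hypothetical stand-in for one GRU decoding step; emits EOS from step 3 on.
+    return paddle.full([batch_size], 38 if i >= 3 else 5, dtype="int64")
+
+eos_id, max_text_length, batch_size = 38, 500, 2
+structure_ids = paddle.zeros([batch_size, max_text_length + 1], dtype="int64")
+steps_run = 0
+for i in range(max_text_length + 1):
+    structure_ids[:, i] = step(i)
+    steps_run += 1
+    # Stop as soon as every sample has produced EOS at some step.
+    if (structure_ids == eos_id).any(-1).all():
+        break
+print(steps_run)  # 4 steps instead of 501
+```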
+
+## 2. Environment Setup
+
+Please refer to [Environment Preparation](./environment.md) to set up the PaddleOCR runtime environment, and to [Project Clone](./clone.md) to clone the repository.
+
+## 3. Training, Evaluation, and Prediction
+
+The SLANet_LCNetv2 model above is trained on the public PubTabNet table recognition dataset; see [table_datasets](./dataset/table_datasets.md) for downloads.
+
+### Launch Training
+
+After downloading the data, follow the [text recognition tutorial](./recognition.md) for training. PaddleOCR is modularized, so training a different model only requires **switching the config file**.
+
+Training commands:
+```shell
+# stage 1
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/table/SLANet_lcnetv2.yml
+# stage 2: load the best model from stage 1 as the pretrained model and lower the learning rate to 0.0001
+# stage 3: load the best model from stage 2 as the pretrained model, keep the learning rate, and change every 488 in the config to 512
+```
+
+## 4. Inference and Deployment
+
+### 4.1 Python Inference
+
+Use the following command to convert the best trained model into an inference model:
+
+```shell
+# Note: set pretrained_model to your local path.
+python3 tools/export_model.py -c configs/table/SLANet_lcnetv2.yml -o Global.pretrained_model=path/best_accuracy Global.save_inference_dir=./inference/slanet_lcnetv2_infer
+```
+
+**Note:**
+- If you trained on your own dataset and changed the dictionary file, make sure `character_dict_path` in the config points to the correct dictionary.
+
+After a successful conversion, the directory contains three files:
+```
+./inference/slanet_lcnetv2_infer/
+    ├── inference.pdiparams         # parameter file of the inference model
+    ├── inference.pdiparams.info    # parameter info of the inference model, can be ignored
+    └── inference.pdmodel           # program file of the inference model
+```
+
+Run the following command for model inference:
+
+```shell
+cd ppstructure/
+python3.7 table/predict_structure.py --table_model_dir=../inference/slanet_lcnetv2_infer/ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt --image_dir=docs/table/table.jpg --output=../output/table_slanet_lcnetv2 --use_gpu=False --benchmark=True --enable_mkldnn=True
+# To predict all images in a folder, set image_dir to the folder, e.g. --image_dir='docs/table'.
+```
+
+After the command finishes, the prediction for the image above (the HTML structure tokens and the coordinates of every table cell) is printed to the screen, and a visualization of the cell coordinates is saved. Example (the structure tokens were stripped by markdown rendering in the original; a representative sequence is restored here, elided with `...`):
+
+```shell
+[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td>', '</td>', ..., '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
+...
+[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
+[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
+[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
+```
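+
+For completeness, here is a minimal, editor-provided sketch of loading the exported files directly with the Paddle Inference API (paths are the ones created above; `predict_structure.py` wraps this plus the pre- and post-processing):
+
+```python
+import numpy as np
+from paddle import inference
+
+# Load the exported program and parameters.
+config = inference.Config(
+    "./inference/slanet_lcnetv2_infer/inference.pdmodel",
+    "./inference/slanet_lcnetv2_infer/inference.pdiparams",
+)
+config.disable_gpu()  # mirror --use_gpu=False above
+predictor = inference.create_predictor(config)
+
+# Feed one normalized, CHW, 488x488 image (dummy data here).
+input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+input_handle.copy_from_cpu(np.zeros((1, 3, 488, 488), dtype="float32"))
+predictor.run()
+outputs = [predictor.get_output_handle(n).copy_to_cpu() for n in predictor.get_output_names()]
+```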
', '', ''], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344], +... +[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]] +[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg +[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246 +``` + + +### 4.2 C++推理部署 + +由于C++预处理后处理还未支持SLANet + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index ffb753f7994..39e413b2d23 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -708,7 +708,7 @@ def __call__(self, data): structure = self.encode(new_structure) if structure is None: return None - + data["length"] = len(structure) structure = [self.start_idx] + structure + [self.end_idx] # add sos abd eos structure = structure + [self.pad_idx] * ( self._max_text_len - len(structure) diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py index 51160a25db9..5f0e7806eed 100644 --- a/ppocr/losses/table_att_loss.py +++ b/ppocr/losses/table_att_loss.py @@ -69,7 +69,8 @@ def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs): def forward(self, predicts, batch): structure_probs = predicts["structure_probs"] structure_targets = batch[1].astype("int64") - structure_targets = structure_targets[:, 1:] + max_len = batch[-2].max() + structure_targets = structure_targets[:, 1 : max_len + 2] structure_loss = self.loss_func(structure_probs, structure_targets) @@ -78,8 +79,8 @@ def forward(self, predicts, batch): loc_preds = predicts["loc_preds"] loc_targets = batch[2].astype("float32") loc_targets_mask = batch[3].astype("float32") - loc_targets = loc_targets[:, 1:, :] - loc_targets_mask = loc_targets_mask[:, 1:, :] + loc_targets = loc_targets[:, 1 : max_len + 2] + loc_targets_mask = loc_targets_mask[:, 1 : max_len + 2] loc_loss = ( F.smooth_l1_loss( diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 0b64992b208..81d107c293c 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -25,6 +25,7 @@ def build_backbone(config, model_type): from .rec_lcnetv3 import PPLCNetV3 from .rec_hgnet import PPHGNet_small from .rec_vit import ViT + from .det_pp_lcnet_v2 import PPLCNetV2_base from .rec_repvit import RepSVTR_det support_dict = [ @@ -35,6 +36,7 @@ def build_backbone(config, model_type): "PPLCNet", "PPLCNetV3", "PPHGNet_small", + "PPLCNetV2_base", "RepSVTR_det", ] if model_type == "table": diff --git a/ppocr/modeling/backbones/det_pp_lcnet_v2.py b/ppocr/modeling/backbones/det_pp_lcnet_v2.py new file mode 100644 index 00000000000..5b5a568a254 --- /dev/null +++ b/ppocr/modeling/backbones/det_pp_lcnet_v2.py @@ -0,0 +1,358 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/ppocr/modeling/backbones/det_pp_lcnet_v2.py b/ppocr/modeling/backbones/det_pp_lcnet_v2.py
new file mode 100644
index 00000000000..5b5a568a254
--- /dev/null
+++ b/ppocr/modeling/backbones/det_pp_lcnet_v2.py
@@ -0,0 +1,358 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+from paddle.utils.download import get_path_from_url
+
+MODEL_URLS = {
+    "PPLCNetV2_small": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_small_ssld_pretrained.pdparams",
+    "PPLCNetV2_base": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_ssld_pretrained.pdparams",
+    "PPLCNetV2_large": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_large_ssld_pretrained.pdparams",
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+NET_CONFIG = {
+    # in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut
+    "stage1": [64, 3, False, False, False, False],
+    "stage2": [128, 3, False, False, False, False],
+    "stage3": [256, 5, True, True, True, False],
+    "stage4": [512, 5, False, True, False, True],
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
+    ):
+        super().__init__()
+        self.use_act = use_act
+        self.conv = Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(initializer=KaimingNormal()),
+            bias_attr=False,
+        )
+
+        self.bn = BatchNorm2D(
+            out_channels,
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+        )
+        if self.use_act:
+            self.act = nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.use_act:
+            x = self.act(x)
+        return x
+
+
+class SEModule(nn.Layer):
+    def __init__(self, channel, reduction=4):
+        super().__init__()
+        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.conv1 = Conv2D(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        self.relu = nn.ReLU()
+        self.conv2 = Conv2D(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        self.hardsigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.hardsigmoid(x)
+        x = paddle.multiply(x=identity, y=x)
+        return x
+
+
+class RepDepthwiseSeparable(nn.Layer):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        stride,
+        dw_size=3,
+        split_pw=False,
+        use_rep=False,
+        use_se=False,
+        use_shortcut=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.is_repped = False
+
+        self.dw_size = dw_size
+        self.split_pw = split_pw
+        self.use_rep = use_rep
+        self.use_se = use_se
+        self.use_shortcut = (
+            True
+            if use_shortcut and stride == 1 and in_channels == out_channels
+            else False
+        )
+
+        if self.use_rep:
+            self.dw_conv_list = nn.LayerList()
+            for kernel_size in range(self.dw_size, 0, -2):
+                if kernel_size == 1 and stride != 1:
+                    continue
+                dw_conv = ConvBNLayer(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    groups=in_channels,
+                    use_act=False,
+                )
+                self.dw_conv_list.append(dw_conv)
+            self.dw_conv = nn.Conv2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=dw_size,
+                stride=stride,
+                padding=(dw_size - 1) // 2,
+                groups=in_channels,
+            )
+        else:
+            self.dw_conv = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=dw_size,
+                stride=stride,
+                groups=in_channels,
+            )
+
+        self.act = nn.ReLU()
+
+        if use_se:
+            self.se = SEModule(in_channels)
+
+        if self.split_pw:
+            pw_ratio = 0.5
+            self.pw_conv_1 = ConvBNLayer(
+                in_channels=in_channels,
+                kernel_size=1,
+                out_channels=int(out_channels * pw_ratio),
+                stride=1,
+            )
+            self.pw_conv_2 = ConvBNLayer(
+                in_channels=int(out_channels * pw_ratio),
+                kernel_size=1,
+                out_channels=out_channels,
+                stride=1,
+            )
+        else:
+            self.pw_conv = ConvBNLayer(
+                in_channels=in_channels,
+                kernel_size=1,
+                out_channels=out_channels,
+                stride=1,
+            )
+
+    def forward(self, x):
+        if self.use_rep:
+            input_x = x
+            if self.is_repped:
+                x = self.act(self.dw_conv(x))
+            else:
+                y = self.dw_conv_list[0](x)
+                for dw_conv in self.dw_conv_list[1:]:
+                    y += dw_conv(x)
+                x = self.act(y)
+        else:
+            x = self.dw_conv(x)
+
+        if self.use_se:
+            x = self.se(x)
+        if self.split_pw:
+            x = self.pw_conv_1(x)
+            x = self.pw_conv_2(x)
+        else:
+            x = self.pw_conv(x)
+        if self.use_shortcut:
+            x = x + input_x
+        return x
+
+    def re_parameterize(self):
+        if self.use_rep:
+            self.is_repped = True
+            kernel, bias = self._get_equivalent_kernel_bias()
+            self.dw_conv.weight.set_value(kernel)
+            self.dw_conv.bias.set_value(bias)
+
+    def _get_equivalent_kernel_bias(self):
+        kernel_sum = 0
+        bias_sum = 0
+        for dw_conv in self.dw_conv_list:
+            kernel, bias = self._fuse_bn_tensor(dw_conv)
+            kernel = self._pad_tensor(kernel, to_size=self.dw_size)
+            kernel_sum += kernel
+            bias_sum += bias
+        return kernel_sum, bias_sum
+
+    def _fuse_bn_tensor(self, branch):
+        kernel = branch.conv.weight
+        running_mean = branch.bn._mean
+        running_var = branch.bn._variance
+        gamma = branch.bn.weight
+        beta = branch.bn.bias
+        eps = branch.bn._epsilon
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape((-1, 1, 1, 1))
+        return kernel * t, beta - running_mean * gamma / std
+
+    def _pad_tensor(self, tensor, to_size):
+        from_size = tensor.shape[-1]
+        if from_size == to_size:
+            return tensor
+        pad = (to_size - from_size) // 2
+        return F.pad(tensor, [pad, pad, pad, pad])
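Editor's note: the `use_rep` branch trains parallel depthwise kernels (5/3/1) and folds them into the single `dw_conv` via `re_parameterize()` for inference. A quick, self-contained equivalence check (a sketch; requires PaddlePaddle and the class above, and must run in eval mode so BatchNorm uses running statistics):

```python
import paddle

block = RepDepthwiseSeparable(
    in_channels=16, out_channels=32, stride=1, dw_size=5, use_rep=True
)
block.eval()  # BN must use running stats for the fused kernel to match
x = paddle.randn([1, 16, 32, 32])
y_multi_branch = block(x)  # sums the 5x5, 3x3 and 1x1 branches
block.re_parameterize()
y_fused = block(x)         # single fused 5x5 depthwise conv
print(paddle.allclose(y_multi_branch, y_fused, atol=1e-5).item())  # True
```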
+
+
+class PPLCNetV2(nn.Layer):
+    def __init__(self, scale, depths, out_indx=[1, 2, 3, 4], **kwargs):
+        super().__init__(**kwargs)
+        self.scale = scale
+        self.out_channels = [
+            int(NET_CONFIG["stage1"][0] * scale * 2),
+            int(NET_CONFIG["stage2"][0] * scale * 2),
+            int(NET_CONFIG["stage3"][0] * scale * 2),
+            int(NET_CONFIG["stage4"][0] * scale * 2),
+        ]
+        self.stem = nn.Sequential(
+            *[
+                ConvBNLayer(
+                    in_channels=3,
+                    kernel_size=3,
+                    out_channels=make_divisible(32 * scale),
+                    stride=2,
+                ),
+                RepDepthwiseSeparable(
+                    in_channels=make_divisible(32 * scale),
+                    out_channels=make_divisible(64 * scale),
+                    stride=1,
+                    dw_size=3,
+                ),
+            ]
+        )
+        self.out_indx = out_indx
+        # stages
+        self.stages = nn.LayerList()
+        for depth_idx, k in enumerate(NET_CONFIG):
+            (
+                in_channels,
+                kernel_size,
+                split_pw,
+                use_rep,
+                use_se,
+                use_shortcut,
+            ) = NET_CONFIG[k]
+            self.stages.append(
+                nn.Sequential(
+                    *[
+                        RepDepthwiseSeparable(
+                            in_channels=make_divisible(
+                                (in_channels if i == 0 else in_channels * 2) * scale
+                            ),
+                            out_channels=make_divisible(in_channels * 2 * scale),
+                            stride=2 if i == 0 else 1,
+                            dw_size=kernel_size,
+                            split_pw=split_pw,
+                            use_rep=use_rep,
+                            use_se=use_se,
+                            use_shortcut=use_shortcut,
+                        )
+                        for i in range(depths[depth_idx])
+                    ]
+                )
+            )
+
+        # always load the ImageNet-pretrained SSLD weights
+        self._load_pretrained(MODEL_URLS["PPLCNetV2_base"], use_ssld=True)
+
+    def forward(self, x):
+        x = self.stem(x)
+        i = 1
+        outs = []
+        for stage in self.stages:
+            x = stage(x)
+            if i in self.out_indx:
+                outs.append(x)
+            i += 1
+        return outs
+
+    def _load_pretrained(self, pretrained_url, use_ssld=False):
+        print(pretrained_url)
+        local_weight_path = get_path_from_url(
+            pretrained_url, os.path.expanduser("~/.paddleclas/weights")
+        )
+        param_state_dict = paddle.load(local_weight_path)
+        self.set_dict(param_state_dict)
+        print("load pretrained ssld weights success!")
+        return
+
+
+def PPLCNetV2_base(in_channels=3, **kwargs):
+    """
+    PPLCNetV2_base
+    Args:
+        in_channels: int=3. Number of input channels, kept for interface
+            compatibility with build_backbone.
+    Returns:
+        model: nn.Layer. A `PPLCNetV2_base` model with SSLD-pretrained weights loaded.
+    """
+    model = PPLCNetV2(scale=1.0, depths=[2, 2, 6, 2], **kwargs)
+    return model
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index d204d5a433a..7e2dc219638 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -24,6 +24,7 @@
 import numpy as np

 from .rec_att_head import AttentionGRUCell
+from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, Mlp


 def get_para_bias_attr(l2_decay, k):
@@ -134,6 +135,120 @@ def forward(self, inputs, targets=None):
         return {"structure_probs": structure_probs, "loc_preds": loc_preds}


+class HWAttention(nn.Layer):
+    def __init__(
+        self,
+        head_dim=32,
+        qk_scale=None,
+        attn_drop=0.0,
+    ):
+        super().__init__()
+        self.head_dim = head_dim
+        self.scale = qk_scale or self.head_dim**-0.5
+        self.attn_drop = nn.Dropout(attn_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        C = C // 3
+        qkv = x.reshape([B, N, 3, C // self.head_dim, self.head_dim]).transpose(
+            [2, 0, 3, 1, 4]
+        )
+        q, k, v = qkv.unbind(0)
+        attn = q @ k.transpose([0, 1, 3, 2]) * self.scale
+        attn = F.softmax(attn, -1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+        x = x.transpose([0, 2, 1, 3]).reshape([B, N, C])
+        return x
+
+
+def img2windows(img, H_sp, W_sp):
+    """
+    img: B H W C
+    """
+    B, H, W, C = img.shape
+    img_reshape = img.reshape([B, H // H_sp, H_sp, W // W_sp, W_sp, C])
+    img_perm = img_reshape.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H_sp * W_sp, C])
+    return img_perm
+
+
+def windows2img(img_splits_hw, H_sp, W_sp, H, W):
+    """
+    img_splits_hw: B' (H_sp*W_sp) C -> B (H*W) C
+    """
+    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
+
+    img = img_splits_hw.reshape([B, H // H_sp, W // W_sp, H_sp, W_sp, -1])
+    img = img.transpose([0, 1, 3, 2, 4, 5]).flatten(1, 4)
+    return img
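Editor's note: a quick sketch checking the window-partition round trip (assumes PaddlePaddle and the two helpers above; note that `windows2img` returns the flattened `[B, H*W, C]` layout the attention block consumes):

```python
import paddle

x = paddle.randn([2, 8, 12, 16])   # [B, H, W, C]
w = img2windows(x, 4, 6)           # [(B * 8/4 * 12/6), 4*6, C] = [8, 24, 16]
y = windows2img(w, 4, 6, 8, 12)    # back to [B, H*W, C] = [2, 96, 16]
print(paddle.allclose(x.reshape([2, 8 * 12, 16]), y).item())  # True
```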
+
+
+class Block(nn.Layer):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        split_h=4,
+        split_w=4,
+        h_num_heads=None,
+        w_num_heads=None,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        eps=1e-6,
+    ):
+        super().__init__()
+        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+        self.split_h = split_h
+        self.split_w = split_w
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.norm1 = norm_layer(dim, epsilon=eps)
+        self.h_num_heads = h_num_heads if h_num_heads is not None else num_heads // 2
+        self.w_num_heads = w_num_heads if w_num_heads is not None else num_heads // 2
+        self.head_dim = dim // num_heads
+        self.mixer = HWAttention(
+            head_dim=dim // num_heads,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+        self.norm2 = norm_layer(dim, epsilon=eps)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = x.flatten(2).transpose([0, 2, 1])
+
+        qkv = self.qkv(x).reshape([B, H, W, 3 * C])
+
+        x1 = qkv[:, :, :, : 3 * self.h_num_heads * self.head_dim]  # b, h, w, 3*ch
+        x2 = qkv[:, :, :, 3 * self.h_num_heads * self.head_dim :]  # b, h, w, 3*cw
+
+        x1 = self.mixer(img2windows(x1, self.split_h, W))  # b*h/split_h, split_h*w, ch
+        x2 = self.mixer(img2windows(x2, H, self.split_w))  # b*w/split_w, h*split_w, cw
+        x1 = windows2img(x1, self.split_h, W, H, W)
+        x2 = windows2img(x2, H, self.split_w, H, W)
+
+        attened_x = paddle.concat([x1, x2], 2)
+        attened_x = self.proj(attened_x)
+
+        x = self.norm1(x + self.drop_path(attened_x))
+        x = self.norm2(x + self.drop_path(self.mlp(x)))
+        x = x.transpose([0, 2, 1]).reshape([-1, C, H, W])
+        return x
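Editor's note: `Block` splits the channel dimension between row-wise and column-wise windowed attention and preserves the `[B, C, H, W]` feature shape. A minimal shape check (a sketch; assumes it runs in the PaddleOCR repo so the `rec_svtrnet` imports above resolve, and that H is divisible by `split_h` and W by `split_w`):

```python
import paddle

blk = Block(dim=96, num_heads=2, qkv_bias=True)
x = paddle.randn([1, 96, 16, 16])  # [B, C, H, W] backbone feature map
y = blk(x)
print(y.shape)  # [1, 96, 16, 16]: same shape, row/column context mixed in
```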
+
+
 class SLAHead(nn.Layer):
     def __init__(
         self,
@@ -143,6 +258,7 @@ def __init__(
         max_text_length=500,
         loc_reg_num=4,
         fc_decay=0.0,
+        use_attn=False,
         **kwargs
     ):
         """
@@ -158,6 +274,7 @@ def __init__(
         self.emb = self._char_to_onehot
         self.num_embeddings = out_channels
         self.loc_reg_num = loc_reg_num
+        self.eos = self.num_embeddings - 1

         # structure
         self.structure_attention_cell = AttentionGRUCell(
@@ -181,6 +298,21 @@ def __init__(
                 hidden_size, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
             ),
         )
+        dpr = np.linspace(0, 0.1, 2)
+
+        self.use_attn = use_attn
+        if use_attn:
+            layer_list = [
+                Block(
+                    in_channels,
+                    num_heads=2,
+                    mlp_ratio=4.0,
+                    qkv_bias=True,
+                    drop_path=dpr[i],
+                )
+                for i in range(2)
+            ]
+            self.cross_atten = nn.Sequential(*layer_list)
         # loc
         weight_attr1, bias_attr1 = get_para_bias_attr(
             l2_decay=fc_decay, k=self.hidden_size
@@ -207,6 +339,8 @@ def __init__(
     def forward(self, inputs, targets=None):
         fea = inputs[-1]
         batch_size = fea.shape[0]
+        if self.use_attn:
+            fea = fea + self.cross_atten(fea)
         # reshape
         fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
         fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
@@ -220,15 +354,22 @@ def forward(self, inputs, targets=None):
         )
         structure_preds.stop_gradient = True
         loc_preds.stop_gradient = True
+
         if self.training and targets is not None:
             structure = targets[0]
-            for i in range(self.max_text_length + 1):
+            max_len = targets[-2].max()
+            for i in range(max_len + 1):
                 hidden, structure_step, loc_step = self._decode(
                     structure[:, i], fea, hidden
                 )
                 structure_preds[:, i, :] = structure_step
                 loc_preds[:, i, :] = loc_step
+            structure_preds = structure_preds[:, : max_len + 1]
+            loc_preds = loc_preds[:, : max_len + 1]
         else:
+            structure_ids = paddle.zeros(
+                (batch_size, self.max_text_length + 1), dtype=paddle.int64
+            )
             pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
             max_text_length = paddle.to_tensor(self.max_text_length)  # for export
@@ -238,8 +379,13 @@ def forward(self, inputs, targets=None):
                 pre_chars = structure_step.argmax(axis=1, dtype="int32")
                 structure_preds[:, i, :] = structure_step
                 loc_preds[:, i, :] = loc_step
+
+                structure_ids[:, i] = pre_chars
+                if (structure_ids == self.eos).any(-1).all():
+                    break
         if not self.training:
-            structure_preds = F.softmax(structure_preds)
+            structure_preds = F.softmax(structure_preds[:, : i + 1])
+            loc_preds = loc_preds[:, : i + 1]
         return {"structure_probs": structure_preds, "loc_preds": loc_preds}

     def _decode(self, pre_chars, features, hidden):