Skip to content

Commit a836921

Browse files
authored
add ocr-det v5 model (#15123)
* add ocr detV5 model * add ocr detV5 pretrained model link
1 parent 0caa3e9 commit a836921

File tree

4 files changed

+355
-3
lines changed

4 files changed

+355
-3
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
Global:
2+
model_name: PP-OCRv5_mobile_det # To use static model for inference.
3+
debug: false
4+
use_gpu: true
5+
epoch_num: &epoch_num 500
6+
log_smooth_window: 20
7+
print_batch_step: 100
8+
save_model_dir: ./output/PP-OCRv5_mobile_det
9+
save_epoch_step: 10
10+
eval_batch_step:
11+
- 0
12+
- 1500
13+
cal_metric_during_train: false
14+
checkpoints:
15+
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
16+
save_inference_dir: null
17+
use_visualdl: false
18+
infer_img: doc/imgs_en/img_10.jpg
19+
save_res_path: ./checkpoints/det_db/predicts_db.txt
20+
d2s_train_image_shape: [3, 640, 640]
21+
distributed: true
22+
23+
Architecture:
24+
model_type: det
25+
algorithm: DB
26+
Transform: null
27+
Backbone:
28+
name: PPLCNetV3
29+
scale: 0.75
30+
det: True
31+
Neck:
32+
name: RSEFPN
33+
out_channels: 96
34+
shortcut: True
35+
Head:
36+
name: DBHead
37+
k: 50
38+
fix_nan: True
39+
40+
Loss:
41+
name: DBLoss
42+
balance_loss: true
43+
main_loss_type: DiceLoss
44+
alpha: 5
45+
beta: 10
46+
ohem_ratio: 3
47+
48+
Optimizer:
49+
name: Adam
50+
beta1: 0.9
51+
beta2: 0.999
52+
lr:
53+
name: Cosine
54+
learning_rate: 0.001 #(8*8c)
55+
warmup_epoch: 2
56+
regularizer:
57+
name: L2
58+
factor: 5.0e-05
59+
60+
PostProcess:
61+
name: DBPostProcess
62+
thresh: 0.3
63+
box_thresh: 0.6
64+
max_candidates: 1000
65+
unclip_ratio: 1.5
66+
67+
Metric:
68+
name: DetMetric
69+
main_indicator: hmean
70+
71+
Train:
72+
dataset:
73+
name: SimpleDataSet
74+
data_dir: ./train_data/icdar2015/text_localization/
75+
label_file_list:
76+
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
77+
ratio_list: [1.0]
78+
transforms:
79+
- DecodeImage:
80+
img_mode: BGR
81+
channel_first: false
82+
- DetLabelEncode: null
83+
- CopyPaste: null
84+
- IaaAugment:
85+
augmenter_args:
86+
- type: Fliplr
87+
args:
88+
p: 0.5
89+
- type: Affine
90+
args:
91+
rotate:
92+
- -10
93+
- 10
94+
- type: Resize
95+
args:
96+
size:
97+
- 0.5
98+
- 3
99+
- EastRandomCropData:
100+
size:
101+
- 640
102+
- 640
103+
max_tries: 50
104+
keep_ratio: true
105+
- MakeBorderMap:
106+
shrink_ratio: 0.4
107+
thresh_min: 0.3
108+
thresh_max: 0.7
109+
total_epoch: *epoch_num
110+
- MakeShrinkMap:
111+
shrink_ratio: 0.4
112+
min_text_size: 8
113+
total_epoch: *epoch_num
114+
- NormalizeImage:
115+
scale: 1./255.
116+
mean:
117+
- 0.485
118+
- 0.456
119+
- 0.406
120+
std:
121+
- 0.229
122+
- 0.224
123+
- 0.225
124+
order: hwc
125+
- ToCHWImage: null
126+
- KeepKeys:
127+
keep_keys:
128+
- image
129+
- threshold_map
130+
- threshold_mask
131+
- shrink_map
132+
- shrink_mask
133+
loader:
134+
shuffle: true
135+
drop_last: false
136+
batch_size_per_card: 8
137+
num_workers: 8
138+
139+
Eval:
140+
dataset:
141+
name: SimpleDataSet
142+
data_dir: ./train_data/icdar2015/text_localization/
143+
label_file_list:
144+
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
145+
transforms:
146+
- DecodeImage:
147+
img_mode: BGR
148+
channel_first: false
149+
- DetLabelEncode: null
150+
- DetResizeForTest:
151+
- NormalizeImage:
152+
scale: 1./255.
153+
mean:
154+
- 0.485
155+
- 0.456
156+
- 0.406
157+
std:
158+
- 0.229
159+
- 0.224
160+
- 0.225
161+
order: hwc
162+
- ToCHWImage: null
163+
- KeepKeys:
164+
keep_keys:
165+
- image
166+
- shape
167+
- polys
168+
- ignore_tags
169+
loader:
170+
shuffle: false
171+
drop_last: false
172+
batch_size_per_card: 1
173+
num_workers: 2
174+
profiler_options: null
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
Global:
2+
model_name: PP-OCRv5_server_det # To use static model for inference.
3+
debug: false
4+
use_gpu: true
5+
epoch_num: &epoch_num 500
6+
log_smooth_window: 20
7+
print_batch_step: 10
8+
save_model_dir: ./output/PP-OCRv5_server_det
9+
save_epoch_step: 10
10+
eval_batch_step:
11+
- 0
12+
- 1500
13+
cal_metric_during_train: false
14+
checkpoints:
15+
pretrained_model: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PPHGNetV2_B4_ocr_det.pdparams
16+
save_inference_dir: null
17+
use_visualdl: false
18+
infer_img: doc/imgs_en/img_10.jpg
19+
save_res_path: ./checkpoints/det_db/predicts_db.txt
20+
distributed: true
21+
22+
Architecture:
23+
model_type: det
24+
algorithm: DB
25+
Transform: null
26+
Backbone:
27+
name: PPHGNetV2_B4
28+
det: True
29+
Neck:
30+
name: LKPAN
31+
out_channels: 256
32+
intracl: true
33+
Head:
34+
name: PFHeadLocal
35+
k: 50
36+
mode: "large"
37+
38+
39+
Loss:
40+
name: DBLoss
41+
balance_loss: true
42+
main_loss_type: DiceLoss
43+
alpha: 5
44+
beta: 10
45+
ohem_ratio: 3
46+
47+
Optimizer:
48+
name: Adam
49+
beta1: 0.9
50+
beta2: 0.999
51+
lr:
52+
name: Cosine
53+
learning_rate: 0.001 #(8*8c)
54+
warmup_epoch: 2
55+
regularizer:
56+
name: L2
57+
factor: 1e-6
58+
59+
PostProcess:
60+
name: DBPostProcess
61+
thresh: 0.3
62+
box_thresh: 0.6
63+
max_candidates: 1000
64+
unclip_ratio: 1.5
65+
66+
Metric:
67+
name: DetMetric
68+
main_indicator: hmean
69+
70+
Train:
71+
dataset:
72+
name: SimpleDataSet
73+
data_dir: ./train_data/icdar2015/text_localization/
74+
label_file_list:
75+
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
76+
ratio_list: [1.0]
77+
transforms:
78+
- DecodeImage:
79+
img_mode: BGR
80+
channel_first: false
81+
- DetLabelEncode: null
82+
- CopyPaste: null
83+
- IaaAugment:
84+
augmenter_args:
85+
- type: Fliplr
86+
args:
87+
p: 0.5
88+
- type: Affine
89+
args:
90+
rotate:
91+
- -10
92+
- 10
93+
- type: Resize
94+
args:
95+
size:
96+
- 0.5
97+
- 3
98+
- EastRandomCropData:
99+
size:
100+
- 640
101+
- 640
102+
max_tries: 50
103+
keep_ratio: true
104+
- MakeBorderMap:
105+
shrink_ratio: 0.4
106+
thresh_min: 0.3
107+
thresh_max: 0.7
108+
total_epoch: *epoch_num
109+
- MakeShrinkMap:
110+
shrink_ratio: 0.4
111+
min_text_size: 8
112+
total_epoch: *epoch_num
113+
- NormalizeImage:
114+
scale: 1./255.
115+
mean:
116+
- 0.485
117+
- 0.456
118+
- 0.406
119+
std:
120+
- 0.229
121+
- 0.224
122+
- 0.225
123+
order: hwc
124+
- ToCHWImage: null
125+
- KeepKeys:
126+
keep_keys:
127+
- image
128+
- threshold_map
129+
- threshold_mask
130+
- shrink_map
131+
- shrink_mask
132+
loader:
133+
shuffle: true
134+
drop_last: false
135+
batch_size_per_card: 8
136+
num_workers: 8
137+
138+
Eval:
139+
dataset:
140+
name: SimpleDataSet
141+
data_dir: ./train_data/icdar2015/text_localization/
142+
label_file_list:
143+
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
144+
transforms:
145+
transforms:
146+
- DecodeImage:
147+
img_mode: BGR
148+
channel_first: false
149+
- DetLabelEncode: null
150+
- DetResizeForTest:
151+
- NormalizeImage:
152+
scale: 1./255.
153+
mean:
154+
- 0.485
155+
- 0.456
156+
- 0.406
157+
std:
158+
- 0.229
159+
- 0.224
160+
- 0.225
161+
order: hwc
162+
- ToCHWImage: null
163+
- KeepKeys:
164+
keep_keys:
165+
- image
166+
- shape
167+
- polys
168+
- ignore_tags
169+
loader:
170+
shuffle: false
171+
drop_last: false
172+
batch_size_per_card: 1
173+
num_workers: 2
174+
profiler_options: null

ppocr/modeling/backbones/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def build_backbone(config, model_type):
2828
from .det_pp_lcnet_v2 import PPLCNetV2_base
2929
from .rec_repvit import RepSVTR_det
3030
from .rec_vary_vit import Vary_VIT_B
31+
from .rec_pphgnetv2 import PPHGNetV2_B4
3132

3233
support_dict = [
3334
"MobileNetV3",
@@ -40,6 +41,7 @@ def build_backbone(config, model_type):
4041
"PPLCNetV2_base",
4142
"RepSVTR_det",
4243
"Vary_VIT_B",
44+
"PPHGNetV2_B4",
4345
]
4446
if model_type == "table":
4547
from .table_master_resnet import TableResNetExtra

ppocr/modeling/backbones/rec_pphgnetv2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,9 +1381,11 @@ def __init__(
13811381
self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
13821382

13831383
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
1384-
self.fc = nn.Linear(
1385-
self.class_expand if self.use_last_conv else out_channels, self.class_num
1386-
)
1384+
if not self.det:
1385+
self.fc = nn.Linear(
1386+
self.class_expand if self.use_last_conv else out_channels,
1387+
self.class_num,
1388+
)
13871389

13881390
self._init_weights()
13891391

0 commit comments

Comments
 (0)