
Commit 03b3b73

support export after save model

1 parent 5b54ac4 commit 03b3b73

File tree

5 files changed: +514 -320 lines


ppocr/utils/export_model.py (new file, +373 lines)

@@ -0,0 +1,373 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import json
import copy
import paddle
import paddle.nn as nn
from paddle.jit import to_static

from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger


def represent_dictionary_order(self, dict_data):
    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())


def setup_orderdict():
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
    if config["Global"].get("hpi_config_path", None):
        hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
        if "RecResizeImg" in config["Eval"]["dataset"]["transforms"]:
            dynamic_shapes = [1] + config["Eval"]["dataset"]["RecResizeImg"][
                "image_shape"
            ]
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "dynamic_shapes"
                ]["x"] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "max_batch_size"
                ] = 1
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
                    "x"
                ] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
        else:
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
        infer_cfg["Hpi"] = hpi_config["Hpi"]
    if config["Global"].get("pdx_model_name", None):
        infer_cfg["Global"] = {}
        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]

    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
    postprocess = OrderedDict()
    for k, v in config["PostProcess"].items():
        postprocess[k] = v

    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        tokenizer_file = config["Global"].get("rec_char_dict_path")
        if tokenizer_file is not None:
            with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
                character_dict = json.load(tokenizer_config_handle)
            postprocess["character_dict"] = character_dict
    else:
        if config["Global"].get("character_dict_path") is not None:
            with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
                lines = f.readlines()
                character_dict = [line.strip("\n") for line in lines]
            postprocess["character_dict"] = character_dict

    infer_cfg["PostProcess"] = postprocess

    with open(path, "w") as f:
        yaml.dump(
            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
        )
    logger.info("Export inference config file to {}".format(os.path.join(path)))
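
For orientation, dump_infer_config writes an inference.yml next to the exported model. A minimal sketch of its shape, following the keys built above (all values are illustrative, not taken from this commit):

Global:
  model_name: PP-OCRv4_mobile_rec   # only present when Global.pdx_model_name is set
PreProcess:
  transform_ops:                    # copied from Eval.dataset.transforms
    - DecodeImage:
        img_mode: BGR
        channel_first: false
PostProcess:
  name: CTCLabelDecode              # plus character_dict when a dict path is configured
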
def export_single_model(
    model, arch_config, save_path, logger, input_shape=None, quanter=None
):
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
            [
                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(shape=[None], dtype="float32")],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        if not input_shape:
            input_shape = [3, 32, 128]
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SATRN"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[
                        None,
                    ],
                    dtype="float32",
                ),
                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [
            [
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(
                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
                ),
            ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "LaTeXOCR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if "Re" in arch_config["Backbone"]["name"]:
            input_spec.extend(
                [
                    paddle.static.InputSpec(
                        shape=[None, 512, 3], dtype="int64"
                    ),  # entities
                    paddle.static.InputSpec(
                        shape=[None, None, 2], dtype="int64"
                    ),  # relations
                ]
            )
        if model.backbone.use_visual_backbone is False:
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec model, H must be 32
            if (
                "Transform" in arch_config
                and arch_config["Transform"] is not None
                and arch_config["Transform"]["name"] == "TPS"
            ):
                logger.info(
                    "When there is TPS in the network, variable-length input is not supported, and the input size must match the size used during training"
                )
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
            ],
        )

    if (
        arch_config["model_type"] != "sr"
        and arch_config["Backbone"]["name"] == "PPLCNetV3"
    ):
        # for rep lcnetv3
        for layer in model.sublayers():
            if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
                layer.rep()

    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return
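
Every branch of export_single_model follows the same pattern: describe the input signature with paddle.static.InputSpec (None and -1 mark dimensions left dynamic), trace the dynamic-graph model with to_static, then serialize with paddle.jit.save. A self-contained sketch of that mechanism, using a toy network and assumed shapes that are not part of this commit:

import paddle
from paddle.jit import to_static


class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


net = TinyNet()
net.eval()
# batch size and width stay dynamic, as in the SVTR_LCNet branch above
static_net = to_static(
    net,
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32")],
)
paddle.jit.save(static_net, "./inference_tiny/inference")
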

def export(config, base_model=None, save_path=None):
    if paddle.distributed.get_rank() != 0:
        return
    logger = get_logger()
    # build post process

    post_process_class = build_post_process(config["PostProcess"], config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key]["return_all_feats"] = False
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num
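
    # Worked example of the head sizing above (numbers are illustrative only):
    # with SARLabelDecode configured and len(post_process_class.character) == 6625,
    # char_num becomes 6625 - 2 = 6623, so a MultiHead is sized as
    # {"CTCLabelDecode": 6623, "SARLabelDecode": 6625, "NRTRLabelDecode": 6626}.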

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True

    # for latexocr algorithm
    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True
    if base_model is not None:
        model = base_model
        if model.__class__.__name__ == "DataParallel":
            model = model._layers
        model = copy.deepcopy(model)
    else:
        model = build_model(config["Architecture"])
        load_model(config, model, model_type=config["Architecture"]["model_type"])
    model.eval()

    if not save_path:
        save_path = config["Global"]["save_inference_dir"]
    yaml_path = os.path.join(save_path, "inference.yml")

    arch_config = config["Architecture"]

    if (
        arch_config["algorithm"] in ["SVTR", "CPPD"]
        and arch_config["Head"]["name"] != "MultiHead"
    ):
        input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
            "image_shape"
        ]
    elif arch_config["algorithm"].lower() == "ABINet".lower():
        rec_rs = [
            c
            for c in config["Eval"]["dataset"]["transforms"]
            if "ABINetRecResizeImg" in c
        ]
        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
    else:
        input_shape = None

    if arch_config["algorithm"] in [
        "Distillation",
    ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(
                model.model_list[idx], archs[idx], sub_model_save_path, logger
            )
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape
        )
    dump_infer_config(config, yaml_path, logger)
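
A minimal driver sketch for the new export entry point (the config path is hypothetical; any training config with Global.save_inference_dir set should work). The base_model argument is what enables "export after save model": training code can pass the live nn.Layer instead of reloading a checkpoint.

import yaml
from ppocr.utils.export_model import export

with open("configs/rec/PP-OCRv4/PP-OCRv4_rec.yml") as f:  # hypothetical path
    config = yaml.safe_load(f)

config["Global"]["save_inference_dir"] = "./inference/rec"

# With base_model=None the checkpoint named in the config is loaded via load_model;
# during training, pass base_model=model to export the just-saved weights directly.
export(config)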
