
Commit 4adde06

support export after save model
1 parent 5b54ac4 commit 4adde06

File tree

5 files changed: +528, -321 lines


ppocr/utils/export_model.py (new file, +381 lines)
@@ -0,0 +1,381 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import json
import copy
import paddle
import paddle.nn as nn
from paddle.jit import to_static

from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger

def represent_dictionary_order(self, dict_data):
    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())


def setup_orderdict():
    yaml.add_representer(OrderedDict, represent_dictionary_order)

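# Write the deployment-time configuration (preprocess transform ops,
# post-process settings, optional HPI backend config and the character
# dictionary) to a YAML file, keeping key order via the OrderedDict
# representer registered above.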
def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
    if config["Global"].get("hpi_config_path", None):
        hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
        rec_resize_img_dict = next(
            (
                item
                for item in config["Eval"]["dataset"]["transforms"]
                if "RecResizeImg" in item
            ),
            None,
        )
        if rec_resize_img_dict:
            dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "dynamic_shapes"
                ]["x"] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
                    "max_batch_size"
                ] = 1
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
                    "x"
                ] = [dynamic_shapes for i in range(3)]
                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
        else:
            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
                del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
                del hpi_config["Hpi"]["backend_config"]["tensorrt"]
        infer_cfg["Hpi"] = hpi_config["Hpi"]
    if config["Global"].get("pdx_model_name", None):
        infer_cfg["Global"] = {}
        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]

    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
    postprocess = OrderedDict()
    for k, v in config["PostProcess"].items():
        postprocess[k] = v

    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        tokenizer_file = config["Global"].get("rec_char_dict_path")
        if tokenizer_file is not None:
            with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
                character_dict = json.load(tokenizer_config_handle)
                postprocess["character_dict"] = character_dict
    else:
        if config["Global"].get("character_dict_path") is not None:
            with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
                lines = f.readlines()
                character_dict = [line.strip("\n") for line in lines]
            postprocess["character_dict"] = character_dict

    infer_cfg["PostProcess"] = postprocess

    with open(path, "w") as f:
        yaml.dump(
            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
        )
    logger.info("Export inference config file to {}".format(os.path.join(path)))

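# Trace one dygraph model to a static graph with paddle.jit.to_static,
# picking the InputSpec (shape/dtype of each input) that the given
# algorithm expects, then save it with paddle.jit.save, or through the
# quanter when exporting a quantized model.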
def export_single_model(
    model, arch_config, save_path, logger, input_shape=None, quanter=None
):
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
            [
                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
                ),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(shape=[None], dtype="float32")],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        if not input_shape:
            input_shape = [3, 32, 128]
        other_shape = [
            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["SATRN"]:
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[
                        None,
                    ],
                    dtype="float32",
                ),
                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [
            [
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
                paddle.static.InputSpec(
                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
                ),
            ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "LaTeXOCR":
        other_shape = [
            paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if "Re" in arch_config["Backbone"]["name"]:
            input_spec.extend(
                [
                    paddle.static.InputSpec(
                        shape=[None, 512, 3], dtype="int64"
                    ),  # entities
                    paddle.static.InputSpec(
                        shape=[None, None, 2], dtype="int64"
                    ),  # relations
                ]
            )
        if model.backbone.use_visual_backbone is False:
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec model, H must be 32
            if (
                "Transform" in arch_config
                and arch_config["Transform"] is not None
                and arch_config["Transform"]["name"] == "TPS"
            ):
                logger.info(
                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                )
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
            ],
        )

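    # PPLCNetV3 backbones use re-parameterizable blocks: fuse each branch
    # into its deployment kernel via layer.rep() before saving, so the
    # exported graph contains the merged weights.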
    if (
        arch_config["model_type"] != "sr"
        and arch_config["Backbone"]["name"] == "PPLCNetV3"
    ):
        # for rep lcnetv3
        for layer in model.sublayers():
            if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
                layer.rep()

    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return

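# Entry point: build or reuse the trained model, adapt head channels to the
# post-process character set, export the static inference model(s), and dump
# the matching inference.yml. Only rank 0 exports in distributed runs.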
def export(config, base_model=None, save_path=None):
    if paddle.distributed.get_rank() != 0:
        return
    logger = get_logger()
    # build post process
    post_process_class = build_post_process(config["PostProcess"], config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key]["return_all_feats"] = False
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True

    # for latexocr algorithm
    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True
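    # New in this commit: when a live model is passed in (e.g. right after
    # save_model() during training), unwrap DataParallel, deep-copy, and
    # export it directly instead of rebuilding and reloading from disk.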
    if base_model is not None:
        model = base_model
        if isinstance(model, paddle.DataParallel):
            model = copy.deepcopy(model._layers)
        else:
            model = copy.deepcopy(model)
    else:
        model = build_model(config["Architecture"])
        load_model(config, model, model_type=config["Architecture"]["model_type"])
    model.eval()

    if not save_path:
        save_path = config["Global"]["save_inference_dir"]
    yaml_path = os.path.join(save_path, "inference.yml")

    arch_config = config["Architecture"]

    if (
        arch_config["algorithm"] in ["SVTR", "CPPD"]
        and arch_config["Head"]["name"] != "MultiHead"
    ):
        input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
            "image_shape"
        ]
    elif arch_config["algorithm"].lower() == "ABINet".lower():
        rec_rs = [
            c
            for c in config["Eval"]["dataset"]["transforms"]
            if "ABINetRecResizeImg" in c
        ]
        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
    else:
        input_shape = None
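    # Distillation configs contain several sub-models: export each one to
    # its own <save_path>/<name>/inference model; otherwise export the
    # single model, and always dump the matching inference.yml alongside.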
    if arch_config["algorithm"] in [
        "Distillation",
    ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(
                model.model_list[idx], archs[idx], sub_model_save_path, logger
            )
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape
        )
    dump_infer_config(config, yaml_path, logger)
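With this helper in place, a training script can export an inference model right after saving a checkpoint, per the commit title. A minimal sketch of that call, assuming a parsed training config dict and a live in-memory model; the output directory name here is illustrative:

    from ppocr.utils.export_model import export

    # `config` is the parsed training YAML; `model` is the live (possibly
    # paddle.DataParallel-wrapped) network that save_model() just wrote out.
    # Passing base_model avoids rebuilding the network and reloading weights;
    # the save_path below is a hypothetical example.
    export(config, base_model=model, save_path="./output/inference")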
