Skip to content

Commit a8b4311

Browse files
committed
support export with pir and no pir
1 parent 0accd26 commit a8b4311

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

ppocr/utils/export_model.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import yaml
1717
import json
1818
import copy
19+
import shutil
1920
import paddle
2021
import paddle.nn as nn
2122
from paddle.jit import to_static
2223

2324
from collections import OrderedDict
25+
from packaging import version
2426
from argparse import ArgumentParser, RawDescriptionHelpFormatter
2527
from ppocr.modeling.architectures import build_model
2628
from ppocr.postprocess import build_post_process
@@ -39,21 +41,23 @@ def setup_orderdict():
3941
def dump_infer_config(config, path, logger):
4042
setup_orderdict()
4143
infer_cfg = OrderedDict()
44+
if not os.path.exists(os.path.dirname(path)):
45+
os.makedirs(os.path.dirname(path))
4246
if config["Global"].get("pdx_model_name", None):
4347
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
4448
if config["Global"].get("uniform_output_enabled", None):
4549
arch_config = config["Architecture"]
4650
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
4751
common_dynamic_shapes = {
48-
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
52+
"x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
4953
}
5054
elif arch_config["model_type"] == "det":
5155
common_dynamic_shapes = {
5256
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
5357
}
5458
elif arch_config["algorithm"] == "SLANet":
5559
common_dynamic_shapes = {
56-
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
60+
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
5761
}
5862
elif arch_config["algorithm"] == "LaTeXOCR":
5963
common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
101105
logger.info("Export inference config file to {}".format(os.path.join(path)))
102106

103107

104-
def export_single_model(
105-
model, arch_config, save_path, logger, input_shape=None, quanter=None
106-
):
108+
def dynamic_to_static(model, arch_config, logger, input_shape=None):
107109
if arch_config["algorithm"] == "SRN":
108110
max_text_length = arch_config["Head"]["max_text_length"]
109111
other_shape = [
@@ -262,9 +264,50 @@ def export_single_model(
262264
for layer in model.sublayers():
263265
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
264266
layer.rep()
267+
return model
268+
269+
270+
def export_single_model(
271+
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
272+
):
273+
274+
model = dynamic_to_static(model, arch_config, logger, input_shape)
265275

266276
if quanter is None:
267-
paddle.jit.save(model, save_path)
277+
paddle_version = version.parse(paddle.__version__)
278+
if paddle_version >= version.parse(
279+
"3.0.0b2"
280+
) or paddle_version == version.parse("0.0.0"):
281+
save_path = os.path.dirname(save_path)
282+
for enable_pir in [True, False]:
283+
if not enable_pir:
284+
save_path_no_pir = os.path.join(save_path, "no_pir", "inference")
285+
model.forward.rollback()
286+
with paddle.pir_utils.OldIrGuard():
287+
model = dynamic_to_static(
288+
model, arch_config, logger, input_shape
289+
)
290+
paddle.jit.save(model, save_path_no_pir)
291+
shutil.copy(
292+
yaml_path,
293+
os.path.join(
294+
os.path.join(save_path, "no_pir"),
295+
os.path.basename(yaml_path),
296+
),
297+
)
298+
else:
299+
save_path_pir = os.path.join(save_path, "pir", "inference")
300+
paddle.jit.save(model, save_path_pir)
301+
shutil.copy(
302+
yaml_path,
303+
os.path.join(
304+
os.path.join(save_path, "pir"), os.path.basename(yaml_path)
305+
),
306+
)
307+
if os.path.exists(yaml_path):
308+
os.remove(yaml_path)
309+
else:
310+
paddle.jit.save(model, save_path)
268311
else:
269312
quanter.save_quantized_model(model, save_path)
270313
logger.info("inference model is saved to {}".format(save_path))
@@ -362,19 +405,22 @@ def export(config, base_model=None, save_path=None):
362405
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
363406
else:
364407
input_shape = None
365-
408+
dump_infer_config(config, yaml_path, logger)
366409
if arch_config["algorithm"] in [
367410
"Distillation",
368411
]: # distillation model
369412
archs = list(arch_config["Models"].values())
370413
for idx, name in enumerate(model.model_name_list):
371414
sub_model_save_path = os.path.join(save_path, name, "inference")
372415
export_single_model(
373-
model.model_list[idx], archs[idx], sub_model_save_path, logger
416+
model.model_list[idx],
417+
archs[idx],
418+
sub_model_save_path,
419+
logger,
420+
yaml_path,
374421
)
375422
else:
376423
save_path = os.path.join(save_path, "inference")
377424
export_single_model(
378-
model, arch_config, save_path, logger, input_shape=input_shape
425+
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
379426
)
380-
dump_infer_config(config, yaml_path, logger)

0 commit comments

Comments
 (0)