Commit b153f10

update hpi config (#14076)
1 parent eb92f24 commit b153f10

File tree

1 file changed (+37, -39 lines)


ppocr/utils/export_model.py

Lines changed: 37 additions & 39 deletions
@@ -39,42 +39,42 @@ def setup_orderdict():
 def dump_infer_config(config, path, logger):
     setup_orderdict()
     infer_cfg = OrderedDict()
-    if config["Global"].get("hpi_config_path", None):
-        hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
-        rec_resize_img_dict = next(
-            (
-                item
-                for item in config["Eval"]["dataset"]["transforms"]
-                if "RecResizeImg" in item
-            ),
-            None,
-        )
-        if rec_resize_img_dict:
-            dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
-            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
-                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
-                    "dynamic_shapes"
-                ]["x"] = [dynamic_shapes for i in range(3)]
-                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
-                    "max_batch_size"
-                ] = 1
-            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
-                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
-                    "x"
-                ] = [dynamic_shapes for i in range(3)]
-                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
-        else:
-            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
-                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
-                del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
-            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
-                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
-                del hpi_config["Hpi"]["backend_config"]["tensorrt"]
-            hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
-        infer_cfg["Hpi"] = hpi_config["Hpi"]
     if config["Global"].get("pdx_model_name", None):
-        infer_cfg["Global"] = {}
-        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
+        infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
+    if config["Global"].get("uniform_output_enabled", None):
+        arch_config = config["Architecture"]
+        if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
+            common_dynamic_shapes = {
+                "x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
+            }
+        elif arch_config["model_type"] == "det":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
+            }
+        elif arch_config["algorithm"] == "SLANet":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
+            }
+        elif arch_config["algorithm"] == "LaTeXOCR":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
+            }
+        else:
+            common_dynamic_shapes = None
+
+        backend_keys = ["paddle_infer", "tensorrt"]
+        hpi_config = {
+            "backend_configs": {
+                key: {
+                    (
+                        "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
+                    ): common_dynamic_shapes
+                }
+                for key in backend_keys
+            }
+        }
+        if common_dynamic_shapes:
+            infer_cfg["Hpi"] = hpi_config

     infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
     postprocess = OrderedDict()
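
For reference, here is a minimal standalone sketch of what the new uniform_output_enabled branch in the hunk above produces for a detection model. The input shapes and the backend_configs / dynamic_shapes / trt_dynamic_shapes key names are taken from the diff; the import, the choice of the det shapes, and the final print are illustrative and not part of the commit.

# Sketch: rebuild the exported "Hpi" section for a det model, using the shapes
# and key names from the diff above. Requires PyYAML.
import yaml

# min / opt / max input shapes hard-coded in the diff for model_type == "det"
common_dynamic_shapes = {
    "x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
}

backend_keys = ["paddle_infer", "tensorrt"]
hpi_config = {
    "backend_configs": {
        # the tensorrt backend uses "dynamic_shapes"; paddle_infer stores the
        # same shapes under "trt_dynamic_shapes"
        key: {
            (
                "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
            ): common_dynamic_shapes
        }
        for key in backend_keys
    }
}

print(yaml.dump({"Hpi": hpi_config}, default_flow_style=False, allow_unicode=True))

Running the sketch should print the same nested mapping that lands under the Hpi key of the exported inference config when uniform_output_enabled is set and the architecture matches one of the listed branches.
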
@@ -96,10 +96,8 @@ def dump_infer_config(config, path, logger):

     infer_cfg["PostProcess"] = postprocess

-    with open(path, "w") as f:
-        yaml.dump(
-            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
-        )
+    with open(path, "w", encoding="utf-8") as f:
+        yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
     logger.info("Export inference config file to {}".format(os.path.join(path)))


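
The second hunk also moves the output encoding from yaml.dump() to open(). Below is a minimal sketch of the new write path, assuming PyYAML; the filename and config contents are placeholders, not values from the commit. As far as PyYAML's emitter is concerned, the encoding= argument only takes effect when the target stream has no encoding attribute, so with a text-mode file the removed encoding="utf-8" argument was effectively ignored and the file was written with the platform default encoding; opening the file with encoding="utf-8" makes UTF-8 explicit, which matters together with allow_unicode=True whenever the config contains non-ASCII values.

# Sketch of the new write path (placeholder filename and contents).
import yaml

infer_cfg = {"Global": {"model_name": "example_model"}}  # placeholder config

# Old pattern (removed): open(path, "w") used the locale's default encoding and
# the encoding="utf-8" passed to yaml.dump() had no effect on a text stream.
# New pattern: encode the file itself as UTF-8 and let yaml.dump() write text.
with open("inference.yml", "w", encoding="utf-8") as f:
    yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
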