Skip to content

Commit 6256dd7

Browse files
【cherry-pick】export with label (#3167)
* export with label * modify
1 parent 889be4f commit 6256dd7

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

deploy/python/postprocess.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,16 @@ def __call__(self, x, file_names=None):
164164

165165

166166
class Topk(object):
167-
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
167+
def __init__(self,
168+
topk=1,
169+
class_id_map_file=None,
170+
delimiter=None,
171+
label_list=None):
168172
assert isinstance(topk, (int, ))
169173
self.topk = topk
170174
delimiter = delimiter if delimiter is not None else " "
171-
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
175+
self.class_id_map = parse_class_id_map(
176+
class_id_map_file, delimiter) if not label_list else label_list
172177

173178
def __call__(self, x, file_names=None):
174179
if file_names is not None:

ppcls/engine/engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ppcls.utils.misc import AverageMeter
2828
from ppcls.utils import logger
2929
from ppcls.utils.logger import init_logger
30-
from ppcls.utils.config import print_config
30+
from ppcls.utils.config import print_config, dump_infer_config
3131
from ppcls.data import build_dataloader
3232
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
3333
from ppcls.arch import apply_to_static
@@ -523,10 +523,9 @@ def export(self):
523523
else:
524524
paddle.jit.save(model, save_path)
525525
if self.config["Global"].get("export_for_fd", False):
526-
src_path = self.config["Global"]["infer_config_path"]
527526
dst_path = os.path.join(
528527
self.config["Global"]["save_inference_dir"], 'inference.yml')
529-
shutil.copy(src_path, dst_path)
528+
dump_infer_config(self.config, dst_path)
530529
logger.info(
531530
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
532531
)

ppcls/utils/config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import yaml
1919
from . import logger
2020
from . import check
21+
from collections import OrderedDict
2122

2223
__all__ = ['get_config']
2324

@@ -213,3 +214,45 @@ def parse_args():
213214
)
214215
args = parser.parse_args()
215216
return args
217+
218+
219+
def represent_dictionary_order(self, dict_data):
220+
return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
221+
222+
223+
def setup_orderdict():
224+
yaml.add_representer(OrderedDict, represent_dictionary_order)
225+
226+
227+
def dump_infer_config(config, path):
228+
setup_orderdict()
229+
infer_cfg = OrderedDict()
230+
transforms = config["Infer"]["transforms"]
231+
for transform in transforms:
232+
if "NormalizeImage" in transform:
233+
transform["NormalizeImage"]["channel_num"] = 3
234+
infer_cfg["PreProcess"] = {
235+
"transform_ops": [
236+
infer_preprocess for infer_preprocess in transforms
237+
if "DecodeImage" not in infer_preprocess
238+
]
239+
}
240+
241+
postprocess_dict = config["Infer"]["PostProcess"]
242+
with open(postprocess_dict["class_id_map_file"], 'r') as f:
243+
label_id_maps = f.readlines()
244+
label_names = []
245+
for line in label_id_maps:
246+
line = line.strip().split(' ', 1)
247+
label_names.append(line[1:][0])
248+
249+
infer_cfg["PostProcess"] = {
250+
"Topk": OrderedDict({
251+
"topk": postprocess_dict["topk"],
252+
"label_list": label_names
253+
})
254+
}
255+
with open(path, 'w') as f:
256+
yaml.dump(infer_cfg, f)
257+
logger.info("Export inference config file to {}".format(
258+
os.path.join(path)))

0 commit comments

Comments
 (0)