Skip to content

Commit a9a730e

Browse files
zhangyubo0722TingquanGao
authored andcommitted
support uniform output
1 parent a7b2356 commit a9a730e

File tree

6 files changed

+247
-47
lines changed

6 files changed

+247
-47
lines changed

ppcls/engine/engine.py

Lines changed: 97 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import os
1818
import shutil
19+
import copy
1920
import platform
2021
import paddle
2122
import paddle.distributed as dist
@@ -38,6 +39,7 @@
3839
from ppcls.utils.ema import ExponentialMovingAverage
3940
from ppcls.utils.save_load import load_dygraph_pretrain
4041
from ppcls.utils.save_load import init_model
42+
from ppcls.utils.save_result import update_train_results
4143
from ppcls.utils import save_load, save_predict_result
4244

4345
from ppcls.data.utils.get_image_list import get_image_list
@@ -169,8 +171,8 @@ def __init__(self, config, mode="train"):
169171
self.config["DataLoader"]["Eval"], "Gallery",
170172
self.device, self.use_dali)
171173
self.query_dataloader = build_dataloader(
172-
self.config["DataLoader"]["Eval"], "Query",
173-
self.device, self.use_dali)
174+
self.config["DataLoader"]["Eval"], "Query", self.device,
175+
self.use_dali)
174176

175177
# build loss
176178
if self.mode == "train":
@@ -210,8 +212,8 @@ def __init__(self, config, mode="train"):
210212
self.config["Global"]["eval_during_train"]):
211213
if self.eval_mode == "classification":
212214
if "Metric" in self.config and "Eval" in self.config["Metric"]:
213-
self.eval_metric_func = build_metrics(self.config["Metric"]
214-
["Eval"])
215+
self.eval_metric_func = build_metrics(self.config["Metric"][
216+
"Eval"])
215217
else:
216218
self.eval_metric_func = None
217219
elif self.eval_mode == "retrieval":
@@ -266,8 +268,7 @@ def __init__(self, config, mode="train"):
266268
self.model = paddle.DataParallel(self.model)
267269
if self.mode == 'train' and len(self.train_loss_func.parameters(
268270
)) > 0:
269-
self.train_loss_func = paddle.DataParallel(
270-
self.train_loss_func)
271+
self.train_loss_func = paddle.DataParallel(self.train_loss_func)
271272

272273
# set different seed in different GPU manually in distributed environment
273274
if seed is None:
@@ -313,6 +314,8 @@ def train(self):
313314
}
314315
# global iter counter
315316
self.global_step = 0
317+
uniform_output_enabled = self.config['Global'].get(
318+
"uniform_output_enabled", False)
316319

317320
if self.config.Global.checkpoints is not None:
318321
metric_info = init_model(self.config.Global, self.model,
@@ -384,41 +387,89 @@ def train(self):
384387
# save best model from best_acc or best_ema_acc
385388
if max(acc, acc_ema) >= max(best_metric["metric"],
386389
best_metric_ema):
390+
metric_info = {
391+
"metric": max(acc, acc_ema),
392+
"epoch": epoch_id
393+
}
394+
prefix = "best_model"
387395
save_load.save_model(
388396
self.model,
389397
self.optimizer,
390-
{"metric": max(acc, acc_ema),
391-
"epoch": epoch_id},
392-
self.output_dir,
398+
metric_info,
399+
os.path.join(self.output_dir, prefix)
400+
if uniform_output_enabled else self.output_dir,
393401
ema=ema_module,
394402
model_name=self.config["Arch"]["name"],
395-
prefix="best_model",
403+
prefix=prefix,
396404
loss=self.train_loss_func,
397405
save_student_model=True)
406+
if uniform_output_enabled:
407+
save_path = os.path.join(self.output_dir, prefix,
408+
"inference")
409+
self.export(save_path, uniform_output_enabled)
410+
if self.ema:
411+
ema_save_path = os.path.join(
412+
self.output_dir, prefix, "inference_ema")
413+
self.export(ema_save_path, uniform_output_enabled)
414+
update_train_results(
415+
self.config, prefix, metric_info, ema=self.ema)
416+
save_load.save_model_info(metric_info, self.output_dir,
417+
prefix)
398418

399419
self.model.train()
400420

401421
# save model
402422
if save_interval > 0 and epoch_id % save_interval == 0:
423+
metric_info = {"metric": acc, "epoch": epoch_id}
424+
prefix = "epoch_{}".format(epoch_id)
403425
save_load.save_model(
404426
self.model,
405-
self.optimizer, {"metric": acc,
406-
"epoch": epoch_id},
407-
self.output_dir,
427+
self.optimizer,
428+
metric_info,
429+
os.path.join(self.output_dir, prefix)
430+
if uniform_output_enabled else self.output_dir,
408431
ema=ema_module,
409432
model_name=self.config["Arch"]["name"],
410-
prefix="epoch_{}".format(epoch_id),
433+
prefix=prefix,
411434
loss=self.train_loss_func)
435+
if uniform_output_enabled:
436+
save_path = os.path.join(self.output_dir, prefix,
437+
"inference")
438+
self.export(save_path, uniform_output_enabled)
439+
if self.ema:
440+
ema_save_path = os.path.join(self.output_dir, prefix,
441+
"inference_ema")
442+
self.export(ema_save_path, uniform_output_enabled)
443+
update_train_results(
444+
self.config,
445+
prefix,
446+
metric_info,
447+
done_flag=epoch_id == self.config["Global"]["epochs"],
448+
ema=self.ema)
449+
save_load.save_model_info(metric_info, self.output_dir,
450+
prefix)
412451
# save the latest model
452+
metric_info = {"metric": acc, "epoch": epoch_id}
453+
prefix = "latest"
413454
save_load.save_model(
414455
self.model,
415-
self.optimizer, {"metric": acc,
416-
"epoch": epoch_id},
417-
self.output_dir,
456+
self.optimizer,
457+
metric_info,
458+
os.path.join(self.output_dir, prefix)
459+
if uniform_output_enabled else self.output_dir,
418460
ema=ema_module,
419461
model_name=self.config["Arch"]["name"],
420-
prefix="latest",
462+
prefix=prefix,
421463
loss=self.train_loss_func)
464+
if uniform_output_enabled:
465+
save_path = os.path.join(self.output_dir, prefix, "inference")
466+
self.export(save_path, uniform_output_enabled)
467+
if self.ema:
468+
ema_save_path = os.path.join(self.output_dir, prefix,
469+
"inference_ema")
470+
self.export(ema_save_path, uniform_output_enabled)
471+
save_load.save_model_info(metric_info, self.output_dir, prefix)
472+
self.model.train()
422473

423474
if self.vdl_writer is not None:
424475
self.vdl_writer.close()
@@ -479,33 +530,45 @@ def infer(self):
479530
image_file_list.clear()
480531
except Exception as ex:
481532
logger.error(
482-
"Exception occured when parse line: {} with msg: {}".
483-
format(image_file, ex))
533+
"Exception occured when parse line: {} with msg: {}".format(
534+
image_file, ex))
484535
continue
485536
if save_path:
486537
save_predict_result(save_path, results)
487538
return results
488539

489-
def export(self):
490-
assert self.mode == "export"
540+
def export(self,
541+
save_path=None,
542+
uniform_output_enabled=False,
543+
ema_module=None):
544+
assert self.mode == "export" or uniform_output_enabled
545+
if paddle.distributed.get_rank() != 0:
546+
return
491547
use_multilabel = self.config["Global"].get(
492548
"use_multilabel",
493549
False) or "ATTRMetric" in self.config["Metric"]["Eval"][0]
494-
model = ExportModel(self.config["Arch"], self.model, use_multilabel)
495-
if self.config["Global"]["pretrained_model"] is not None:
550+
model = self.model_ema.module if self.ema else self.model
551+
if isinstance(self.model, paddle.DataParallel):
552+
model = copy.deepcopy(model._layers)
553+
else:
554+
model = copy.deepcopy(model)
555+
model = ExportModel(self.config["Arch"], model
556+
if not ema_module else ema_module, use_multilabel)
557+
if self.config["Global"][
558+
"pretrained_model"] is not None and not uniform_output_enabled:
496559
load_dygraph_pretrain(model.base_model,
497560
self.config["Global"]["pretrained_model"])
498-
499561
model.eval()
500-
501562
# for re-parameterization nets
502-
for layer in self.model.sublayers():
563+
for layer in model.sublayers():
503564
if hasattr(layer, "re_parameterize") and not getattr(layer,
504565
"is_repped"):
505566
layer.re_parameterize()
506-
507-
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
508-
"inference")
567+
if not save_path:
568+
save_path = os.path.join(
569+
self.config["Global"]["save_inference_dir"], "inference")
570+
else:
571+
save_path = os.path.join(save_path, "inference")
509572

510573
model = paddle.jit.to_static(
511574
model,
@@ -520,12 +583,12 @@ def export(self):
520583
save_path + "_int8")
521584
else:
522585
paddle.jit.save(model, save_path)
523-
if self.config["Global"].get("export_for_fd", False):
524-
dst_path = os.path.join(
525-
self.config["Global"]["save_inference_dir"], 'inference.yml')
586+
if self.config["Global"].get("export_for_fd",
587+
False) or uniform_output_enabled:
588+
dst_path = os.path.join(os.path.dirname(save_path), 'inference.yml')
526589
dump_infer_config(self.config, dst_path)
527590
logger.info(
528-
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
591+
f"Export succeeded! The inference model exported has been saved in \"{save_path}\"."
529592
)
530593

531594
def _init_amp(self):

ppcls/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from . import misc
1818
from . import model_zoo
1919

20-
from .config import get_config
20+
from .config import get_config, convert_to_dict
2121
from .dist_utils import all_gather
2222
from .metrics import accuracy_score
2323
from .metrics import hamming_distance

ppcls/utils/config.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,16 @@
2020
from . import check
2121
from collections import OrderedDict
2222

23-
__all__ = ['get_config']
23+
__all__ = ['get_config', 'convert_to_dict']
24+
25+
26+
def convert_to_dict(obj):
27+
if isinstance(obj, dict):
28+
return {k: convert_to_dict(v) for k, v in obj.items()}
29+
elif isinstance(obj, list):
30+
return [convert_to_dict(i) for i in obj]
31+
else:
32+
return obj
2433

2534

2635
class AttrDict(dict):
@@ -223,16 +232,49 @@ def setup_orderdict():
223232
yaml.add_representer(OrderedDict, represent_dictionary_order)
224233

225234

226-
def dump_infer_config(config, path):
235+
def dump_infer_config(inference_config, path):
227236
setup_orderdict()
228237
infer_cfg = OrderedDict()
238+
config = copy.deepcopy(inference_config)
229239
if config.get("Infer"):
230240
transforms = config["Infer"]["transforms"]
231241
elif config["DataLoader"]["Eval"].get("Query"):
232-
transforms = config["DataLoader"]["Eval"]["Query"]["dataset"]["transform_ops"]
242+
transforms = config["DataLoader"]["Eval"]["Query"]["dataset"][
243+
"transform_ops"]
233244
transforms.append({"ToCHWImage": None})
234245
else:
235246
logger.error("This config does not support dump transform config!")
247+
transform = next((item for item in transforms if 'CropImage' in item), None)
248+
if transform:
249+
dynamic_shapes = transform["CropImage"]["size"]
250+
else:
251+
transform = next((item for item in transforms
252+
if 'ResizeImage' in item), None)
253+
if transform:
254+
dynamic_shapes = transform["ResizeImage"]["size"][0]
255+
else:
256+
dynamic_shapes = 224
257+
# Configuration required config for high-performance inference.
258+
if config["Global"].get("hpi_config_path", None):
259+
hpi_config = convert_to_dict(
260+
parse_config(config["Global"]["hpi_config_path"]))
261+
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
262+
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
263+
"dynamic_shapes"]["x"] = [[
264+
1, 3, dynamic_shapes, dynamic_shapes
265+
] for i in range(3)]
266+
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
267+
"max_batch_size"] = 1
268+
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
269+
hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
270+
"x"] = [[1, 3, dynamic_shapes, dynamic_shapes]
271+
for i in range(3)]
272+
hpi_config["Hpi"]["backend_config"]["tensorrt"][
273+
"max_batch_size"] = 1
274+
infer_cfg["Hpi"] = hpi_config["Hpi"]
275+
if config["Global"].get("pdx_model_name", None):
276+
infer_cfg["Global"] = {}
277+
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
236278
for transform in transforms:
237279
if "NormalizeImage" in transform:
238280
transform["NormalizeImage"]["channel_num"] = 3
@@ -262,7 +304,7 @@ def dump_infer_config(config, path):
262304
postprocess_dict.pop("name")
263305
dic = OrderedDict()
264306
for item in postprocess_dict.items():
265-
dic[item[0]] = item[1]
307+
dic[item[0]] = item[1]
266308
dic['label_list'] = label_names
267309

268310
if postprocess_name:

ppcls/utils/save_load.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import errno
2020
import os
21+
import json
2122

2223
import paddle
2324
from . import logger
@@ -108,8 +109,7 @@ def load_distillation_model(model, pretrained_model):
108109
student = model.student if hasattr(model,
109110
"student") else model._layers.student
110111
load_dygraph_pretrain(teacher, path=pretrained_model[0])
111-
logger.info("Finish initing teacher model from {}".format(
112-
pretrained_model))
112+
logger.info("Finish initing teacher model from {}".format(pretrained_model))
113113
# load student model
114114
if len(pretrained_model) >= 2:
115115
load_dygraph_pretrain(student, path=pretrained_model[1])
@@ -188,8 +188,7 @@ def save_model(net,
188188
params_state_dict = net.state_dict()
189189
if loss is not None:
190190
loss_state_dict = loss.state_dict()
191-
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
192-
))
191+
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
193192
assert len(keys_inter) == 0, \
194193
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
195194
params_state_dict.update(loss_state_dict)
@@ -210,3 +209,15 @@ def save_model(net,
210209
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
211210
paddle.save(metric_info, model_path + ".pdstates")
212211
logger.info("Already save model in {}".format(model_path))
212+
213+
214+
def save_model_info(model_info, save_path, prefix):
215+
"""
216+
save model info to the target path
217+
"""
218+
save_path = os.path.join(save_path, prefix)
219+
if not os.path.exists(save_path):
220+
os.makedirs(save_path)
221+
with open(os.path.join(save_path, f'{prefix}.info.json'), 'w') as f:
222+
json.dump(model_info, f)
223+
logger.info("Already save model info in {}".format(save_path))

0 commit comments

Comments
 (0)