Skip to content

Commit 1f3ce3a

Browse files
fix save two infer models (#9250)
1 parent 3ba5d74 commit 1f3ce3a

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

ppdet/engine/trainer.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,21 @@ def _get_save_image_name(self, output_dir, image_path):
11451145
name, ext = os.path.splitext(image_name)
11461146
return os.path.join(output_dir, "{}".format(name)) + ext
11471147

1148+
def _model_to_static(self, model, input_spec, prune_input=True):
1149+
if prune_input:
1150+
static_model = paddle.jit.to_static(
1151+
model, input_spec=input_spec, full_graph=True)
1152+
# NOTE: dy2st do not pruned program, but jit.save will prune program
1153+
# input spec, prune input spec here and save with pruned input spec
1154+
pruned_input_spec = _prune_input_spec(
1155+
input_spec, static_model.forward.main_program,
1156+
static_model.forward.outputs)
1157+
else:
1158+
static_model = None
1159+
pruned_input_spec = input_spec
1160+
1161+
return static_model, pruned_input_spec
1162+
11481163
def _get_infer_cfg_and_input_spec(self,
11491164
save_dir,
11501165
prune_input=True,
@@ -1226,18 +1241,8 @@ def _get_infer_cfg_and_input_spec(self,
12261241
"full_img_path": str,
12271242
"img_name": str,
12281243
})
1229-
if prune_input:
12301244

1231-
static_model = paddle.jit.to_static(
1232-
model, input_spec=input_spec, full_graph=True)
1233-
# NOTE: dy2st do not pruned program, but jit.save will prune program
1234-
# input spec, prune input spec here and save with pruned input spec
1235-
pruned_input_spec = _prune_input_spec(
1236-
input_spec, static_model.forward.main_program,
1237-
static_model.forward.outputs)
1238-
else:
1239-
static_model = None
1240-
pruned_input_spec = input_spec
1245+
static_model, pruned_input_spec = self._model_to_static(model, input_spec, prune_input)
12411246

12421247
# TODO: Hard code, delete it when support prune input_spec.
12431248
if self.cfg.architecture == 'PicoDet' and not export_post_process:
@@ -1259,7 +1264,7 @@ def _get_infer_cfg_and_input_spec(self,
12591264
shape=image_shape, name='image')
12601265
}]
12611266

1262-
return static_model, pruned_input_spec
1267+
return static_model, pruned_input_spec, input_spec
12631268

12641269
def export(self, output_dir='output_inference', for_fd=False):
12651270
if hasattr(self.model, 'aux_neck'):
@@ -1285,8 +1290,7 @@ def export(self, output_dir='output_inference', for_fd=False):
12851290

12861291
if not os.path.exists(save_dir):
12871292
os.makedirs(save_dir)
1288-
1289-
static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
1293+
static_model, pruned_input_spec, input_spec = self._get_infer_cfg_and_input_spec(
12901294
save_dir, yaml_name=yaml_name, model=model)
12911295

12921296
# dy2st and save model
@@ -1299,8 +1303,8 @@ def export(self, output_dir='output_inference', for_fd=False):
12991303
static_model.forward.rollback()
13001304
with paddle.pir_utils.OldIrGuard():
13011305
save_path_no_pir = save_dir
1302-
static_model, pruned_input_spec, = self._get_infer_cfg_and_input_spec(
1303-
save_dir, yaml_name=yaml_name, model=model)
1306+
static_model, pruned_input_spec = self._model_to_static(
1307+
model, input_spec)
13041308
paddle.jit.save(static_model, os.path.join(save_path_no_pir, save_name), input_spec=pruned_input_spec)
13051309
else:
13061310
save_path_pir = os.path.join(os.path.dirname(save_dir), f"{os.path.basename(save_dir)}_pir")

0 commit comments

Comments
 (0)