Skip to content

Commit cdbf7ad

Browse files
committed
modify export with pir
1 parent b15da9a commit cdbf7ad

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

ppdet/engine/trainer.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,28 +1295,17 @@ def export(self, output_dir='output_inference', for_fd=False):
12951295

12961296
# dy2st and save model
12971297
if 'slim' not in self.cfg or 'QAT' not in self.cfg['slim_type']:
1298-
paddle_version = version.parse(paddle.__version__)
1299-
if (paddle_version >= version.parse(
1300-
'3.0.0b2') or paddle_version == version.parse('0.0.0')) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
1301-
for enable_pir in [True, False]:
1302-
if not enable_pir:
1303-
static_model.forward.rollback()
1304-
with paddle.pir_utils.OldIrGuard():
1305-
save_path_no_pir = save_dir
1306-
static_model, pruned_input_spec = self._model_to_static(
1307-
model, input_spec)
1308-
paddle.jit.save(static_model, os.path.join(save_path_no_pir, save_name), input_spec=pruned_input_spec)
1309-
else:
1310-
save_path_pir = os.path.join(os.path.dirname(save_dir), f"{os.path.basename(save_dir)}_pir")
1311-
paddle.jit.save(static_model, os.path.join(save_path_pir, save_name), input_spec=pruned_input_spec)
1312-
shutil.copy(
1313-
os.path.join(save_dir, yaml_name),
1314-
os.path.join(
1315-
save_path_pir, yaml_name
1316-
),
1317-
)
1318-
else:
1298+
if self.cfg.get("export_with_pir", False):
1299+
paddle_version = version.parse(paddle.__version__)
1300+
assert (paddle_version >= version.parse(
1301+
'3.0.0b2') or paddle_version == version.parse('0.0.0')) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
13191302
paddle.jit.save(static_model, os.path.join(save_dir, save_name), input_spec=pruned_input_spec)
1303+
else:
1304+
static_model.forward.rollback()
1305+
with paddle.pir_utils.OldIrGuard():
1306+
static_model, pruned_input_spec = self._model_to_static(
1307+
model, input_spec)
1308+
paddle.jit.save(static_model, os.path.join(save_dir, save_name), input_spec=pruned_input_spec)
13201309
else:
13211310
self.cfg.slim.save_quantized_model(
13221311
self.model,

0 commit comments

Comments
 (0)