Skip to content

Commit 0e90ae5

Browse files
modify export with pir (#3332)
1 parent bd23af2 commit 0e90ae5

File tree

1 file changed

+18
-30
lines changed

1 file changed

+18
-30
lines changed

ppcls/engine/engine.py

+18-30
Original file line numberDiff line numberDiff line change
@@ -606,37 +606,25 @@ def export(self,
606606
model.base_model.quanter.save_quantized_model(model,
607607
save_path + "_int8")
608608
else:
609-
paddle_version = version.parse(paddle.__version__)
610-
if (paddle_version >= version.parse('3.0.0b2') or paddle_version ==
611-
version.parse('0.0.0')) and os.environ.get(
612-
"FLAGS_enable_pir_api", None) not in ["0", "False"]:
613-
save_path = os.path.dirname(save_path)
614-
for enable_pir in [True, False]:
615-
if not enable_pir:
616-
save_path_no_pir = os.path.join(save_path, "inference")
617-
model.forward.rollback()
618-
with paddle.pir_utils.OldIrGuard():
619-
model = paddle.jit.to_static(
620-
model,
621-
input_spec=[
622-
paddle.static.InputSpec(
623-
shape=[None] +
624-
self.config["Global"]["image_shape"],
625-
dtype='float32')
626-
])
627-
paddle.jit.save(model, save_path_no_pir)
628-
else:
629-
save_path_pir = os.path.join(
630-
os.path.dirname(save_path),
631-
f"{os.path.basename(save_path)}_pir", "inference")
632-
paddle.jit.save(model, save_path_pir)
633-
shutil.copy(
634-
dst_path,
635-
os.path.join(
636-
os.path.dirname(save_path_pir),
637-
os.path.basename(dst_path)), )
638-
else:
609+
if self.config["Global"].get("export_with_pir", False):
610+
paddle_version = version.parse(paddle.__version__)
611+
assert (paddle_version >= version.parse('3.0.0b2') or
612+
paddle_version == version.parse('0.0.0')
613+
) and os.environ.get("FLAGS_enable_pir_api",
614+
None) not in ["0", "False"]
639615
paddle.jit.save(model, save_path)
616+
else:
617+
model.forward.rollback()
618+
with paddle.pir_utils.OldIrGuard():
619+
model = paddle.jit.to_static(
620+
model,
621+
input_spec=[
622+
paddle.static.InputSpec(
623+
shape=[None] +
624+
self.config["Global"]["image_shape"],
625+
dtype='float32')
626+
])
627+
paddle.jit.save(model, save_path)
640628
logger.info(
641629
f"Export succeeded! The inference model exported has been saved in \"{save_path}\"."
642630
)

0 commit comments

Comments
 (0)