Skip to content

Commit f57a8d1

Browse files
add skip_prune_program arg for Solver.export
1 parent 25bb1bd commit f57a8d1

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

ppsci/solver/solver.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import sys
2323
from os import path as osp
24+
from typing import TYPE_CHECKING
2425
from typing import Callable
2526
from typing import Dict
2627
from typing import List
@@ -41,7 +42,6 @@
4142
from paddle import optimizer as optim
4243
from paddle.distributed import fleet
4344
from paddle.framework import core
44-
from paddle.static import InputSpec
4545
from typing_extensions import Literal
4646

4747
import ppsci
@@ -51,6 +51,9 @@
5151
from ppsci.utils import misc
5252
from ppsci.utils import save_load
5353

54+
if TYPE_CHECKING:
55+
from paddle.static import InputSpec
56+
5457

5558
class Solver:
5659
"""Class for solver.
@@ -729,7 +732,11 @@ def predict(
729732

730733
@misc.run_on_eval_mode
731734
def export(
732-
self, input_spec: List[InputSpec], export_path: str, with_onnx: bool = False
735+
self,
736+
input_spec: List["InputSpec"],
737+
export_path: str,
738+
with_onnx: bool = False,
739+
skip_prune_program: bool = False,
733740
):
734741
"""
735742
Convert model to static graph model and export to files.
@@ -740,6 +747,8 @@ def export(
740747
export_path (str): The path prefix to save model.
741748
with_onnx (bool, optional): Whether to export model into onnx after
742749
paddle inference models are exported.
750+
skip_prune_program (bool, optional): Whether prune program, pruning program
751+
may cause unexpectable result, e.g. llm-inference.
743752
"""
744753
jit.enable_to_static(True)
745754

@@ -760,7 +769,7 @@ def export(
760769
if len(osp.dirname(export_path)):
761770
os.makedirs(osp.dirname(export_path), exist_ok=True)
762771
try:
763-
jit.save(static_model, export_path)
772+
jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
764773
except Exception as e:
765774
raise e
766775
logger.message(

0 commit comments

Comments
 (0)