@@ -21,6 +21,7 @@
 import os
 import sys
 from os import path as osp
+from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Dict
 from typing import List
@@ -41,7 +42,6 @@
 from paddle import optimizer as optim
 from paddle.distributed import fleet
 from paddle.framework import core
-from paddle.static import InputSpec
 from typing_extensions import Literal
 
 import ppsci
@@ -51,6 +51,9 @@
 from ppsci.utils import misc
 from ppsci.utils import save_load
 
+if TYPE_CHECKING:
+    from paddle.static import InputSpec
+
 
 class Solver:
     """Class for solver.
@@ -729,7 +732,11 @@ def predict(
 
     @misc.run_on_eval_mode
     def export(
-        self, input_spec: List[InputSpec], export_path: str, with_onnx: bool = False
+        self,
+        input_spec: List["InputSpec"],
+        export_path: str,
+        with_onnx: bool = False,
+        skip_prune_program: bool = False,
     ):
         """
         Convert model to static graph model and export to files.
@@ -740,6 +747,8 @@ def export(
             export_path (str): The path prefix to save model.
             with_onnx (bool, optional): Whether to export model into onnx after
                 paddle inference models are exported.
+            skip_prune_program (bool, optional): Whether to skip pruning the
+                program; pruning may cause unexpected results, e.g. in LLM inference.
         """
         jit.enable_to_static(True)
 
@@ -760,7 +769,7 @@ def export(
         if len(osp.dirname(export_path)):
             os.makedirs(osp.dirname(export_path), exist_ok=True)
         try:
-            jit.save(static_model, export_path)
+            jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
         except Exception as e:
             raise e
         logger.message(
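For reference, a hedged usage sketch of the new flag from the caller's side (the MLP model, checkpoint path, and input spec below are illustrative placeholders, not part of this PR):

```python
import ppsci
from paddle.static import InputSpec

model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)  # placeholder network
solver = ppsci.solver.Solver(model, pretrained_model_path="./ckpt/model")

# skip_prune_program=True is forwarded to paddle.jit.save so the exported
# program is kept whole; useful when pruning would drop operators the
# deployed model still needs (e.g. LLM inference, per the docstring above).
solver.export(
    input_spec=[{k: InputSpec([None, 1], "float32", name=k) for k in model.input_keys}],
    export_path="./inference/model",
    skip_prune_program=True,
)
```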