@@ -606,37 +606,25 @@ def export(self,
             model.base_model.quanter.save_quantized_model(model,
                                                           save_path + "_int8")
         else:
-            paddle_version = version.parse(paddle.__version__)
-            if (paddle_version >= version.parse('3.0.0b2') or paddle_version ==
-                    version.parse('0.0.0')) and os.environ.get(
-                        "FLAGS_enable_pir_api", None) not in ["0", "False"]:
-                save_path = os.path.dirname(save_path)
-                for enable_pir in [True, False]:
-                    if not enable_pir:
-                        save_path_no_pir = os.path.join(save_path, "inference")
-                        model.forward.rollback()
-                        with paddle.pir_utils.OldIrGuard():
-                            model = paddle.jit.to_static(
-                                model,
-                                input_spec=[
-                                    paddle.static.InputSpec(
-                                        shape=[None] +
-                                        self.config["Global"]["image_shape"],
-                                        dtype='float32')
-                                ])
-                            paddle.jit.save(model, save_path_no_pir)
-                    else:
-                        save_path_pir = os.path.join(
-                            os.path.dirname(save_path),
-                            f"{os.path.basename(save_path)}_pir", "inference")
-                        paddle.jit.save(model, save_path_pir)
-                        shutil.copy(
-                            dst_path,
-                            os.path.join(
-                                os.path.dirname(save_path_pir),
-                                os.path.basename(dst_path)), )
-            else:
+            if self.config["Global"].get("export_with_pir", False):
+                paddle_version = version.parse(paddle.__version__)
+                assert (paddle_version >= version.parse('3.0.0b2') or
+                        paddle_version == version.parse('0.0.0')
+                        ) and os.environ.get("FLAGS_enable_pir_api",
+                                             None) not in ["0", "False"]
                 paddle.jit.save(model, save_path)
+            else:
+                model.forward.rollback()
+                with paddle.pir_utils.OldIrGuard():
+                    model = paddle.jit.to_static(
+                        model,
+                        input_spec=[
+                            paddle.static.InputSpec(
+                                shape=[None] +
+                                self.config["Global"]["image_shape"],
+                                dtype='float32')
+                        ])
+                    paddle.jit.save(model, save_path)
         logger.info(
             f"Export succeeded! The inference model exported has been saved in \"{save_path}\"."
         )
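For context, here is a minimal standalone sketch of the export flow this hunk converges on: a single `paddle.jit.save` call gated by an `export_with_pir`-style flag (the config key used in the added lines), rather than writing both a PIR and a non-PIR copy. The function name `export_inference_model` and its signature are illustrative only, not part of the repository; the version/flag checks and the `OldIrGuard` fallback mirror the diff.

```python
# Illustrative sketch only -- not the repository's API. It assumes the layer
# has not yet been converted with paddle.jit.to_static.
import os

import paddle
from packaging import version


def export_inference_model(model, image_shape, save_path,
                           export_with_pir=False):
    """Export `model` for inference via the PIR path or the legacy IR path."""
    input_spec = [
        paddle.static.InputSpec(shape=[None] + image_shape, dtype='float32')
    ]
    if export_with_pir:
        # PIR export needs Paddle >= 3.0.0b2 (or a dev build reporting 0.0.0)
        # and FLAGS_enable_pir_api left enabled, as asserted in the diff.
        paddle_version = version.parse(paddle.__version__)
        assert (paddle_version >= version.parse('3.0.0b2') or
                paddle_version == version.parse('0.0.0'))
        assert os.environ.get("FLAGS_enable_pir_api",
                              None) not in ["0", "False"]
        static_model = paddle.jit.to_static(model, input_spec=input_spec)
        paddle.jit.save(static_model, save_path)
    else:
        # Legacy path: build and save the old-IR program under OldIrGuard.
        with paddle.pir_utils.OldIrGuard():
            static_model = paddle.jit.to_static(model, input_spec=input_spec)
            paddle.jit.save(static_model, save_path)
```

In the hunk itself the layer appears to have been decorated with `to_static` earlier in `export()`, which is why the PIR branch only calls `paddle.jit.save` directly while the legacy branch first calls `model.forward.rollback()` and re-converts under `OldIrGuard`.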