@@ -1295,28 +1295,17 @@ def export(self, output_dir='output_inference', for_fd=False):
1295
1295
1296
1296
# dy2st and save model
1297
1297
if 'slim' not in self .cfg or 'QAT' not in self .cfg ['slim_type' ]:
1298
- paddle_version = version .parse (paddle .__version__ )
1299
- if (paddle_version >= version .parse (
1300
- '3.0.0b2' ) or paddle_version == version .parse ('0.0.0' )) and os .environ .get ("FLAGS_enable_pir_api" , None ) not in ["0" , "False" ]:
1301
- for enable_pir in [True , False ]:
1302
- if not enable_pir :
1303
- static_model .forward .rollback ()
1304
- with paddle .pir_utils .OldIrGuard ():
1305
- save_path_no_pir = save_dir
1306
- static_model , pruned_input_spec = self ._model_to_static (
1307
- model , input_spec )
1308
- paddle .jit .save (static_model , os .path .join (save_path_no_pir , save_name ), input_spec = pruned_input_spec )
1309
- else :
1310
- save_path_pir = os .path .join (os .path .dirname (save_dir ), f"{ os .path .basename (save_dir )} _pir" )
1311
- paddle .jit .save (static_model , os .path .join (save_path_pir , save_name ), input_spec = pruned_input_spec )
1312
- shutil .copy (
1313
- os .path .join (save_dir , yaml_name ),
1314
- os .path .join (
1315
- save_path_pir , yaml_name
1316
- ),
1317
- )
1318
- else :
1298
+ if self .cfg .get ("export_with_pir" , False ):
1299
+ paddle_version = version .parse (paddle .__version__ )
1300
+ assert (paddle_version >= version .parse (
1301
+ '3.0.0b2' ) or paddle_version == version .parse ('0.0.0' )) and os .environ .get ("FLAGS_enable_pir_api" , None ) not in ["0" , "False" ]
1319
1302
paddle .jit .save (static_model , os .path .join (save_dir , save_name ), input_spec = pruned_input_spec )
1303
+ else :
1304
+ static_model .forward .rollback ()
1305
+ with paddle .pir_utils .OldIrGuard ():
1306
+ static_model , pruned_input_spec = self ._model_to_static (
1307
+ model , input_spec )
1308
+ paddle .jit .save (static_model , os .path .join (save_dir , save_name ), input_spec = pruned_input_spec )
1320
1309
else :
1321
1310
self .cfg .slim .save_quantized_model (
1322
1311
self .model ,
0 commit comments