@@ -1145,6 +1145,21 @@ def _get_save_image_name(self, output_dir, image_path):
1145
1145
name , ext = os .path .splitext (image_name )
1146
1146
return os .path .join (output_dir , "{}" .format (name )) + ext
1147
1147
1148
+ def _model_to_static (self , model , input_spec , prune_input = True ):
1149
+ if prune_input :
1150
+ static_model = paddle .jit .to_static (
1151
+ model , input_spec = input_spec , full_graph = True )
1152
+ # NOTE: dy2st do not pruned program, but jit.save will prune program
1153
+ # input spec, prune input spec here and save with pruned input spec
1154
+ pruned_input_spec = _prune_input_spec (
1155
+ input_spec , static_model .forward .main_program ,
1156
+ static_model .forward .outputs )
1157
+ else :
1158
+ static_model = None
1159
+ pruned_input_spec = input_spec
1160
+
1161
+ return static_model , pruned_input_spec
1162
+
1148
1163
def _get_infer_cfg_and_input_spec (self ,
1149
1164
save_dir ,
1150
1165
prune_input = True ,
@@ -1226,18 +1241,8 @@ def _get_infer_cfg_and_input_spec(self,
1226
1241
"full_img_path" : str ,
1227
1242
"img_name" : str ,
1228
1243
})
1229
- if prune_input :
1230
1244
1231
- static_model = paddle .jit .to_static (
1232
- model , input_spec = input_spec , full_graph = True )
1233
- # NOTE: dy2st do not pruned program, but jit.save will prune program
1234
- # input spec, prune input spec here and save with pruned input spec
1235
- pruned_input_spec = _prune_input_spec (
1236
- input_spec , static_model .forward .main_program ,
1237
- static_model .forward .outputs )
1238
- else :
1239
- static_model = None
1240
- pruned_input_spec = input_spec
1245
+ static_model , pruned_input_spec = self ._model_to_static (model , input_spec , prune_input )
1241
1246
1242
1247
# TODO: Hard code, delete it when support prune input_spec.
1243
1248
if self .cfg .architecture == 'PicoDet' and not export_post_process :
@@ -1259,7 +1264,7 @@ def _get_infer_cfg_and_input_spec(self,
1259
1264
shape = image_shape , name = 'image' )
1260
1265
}]
1261
1266
1262
- return static_model , pruned_input_spec
1267
+ return static_model , pruned_input_spec , input_spec
1263
1268
1264
1269
def export (self , output_dir = 'output_inference' , for_fd = False ):
1265
1270
if hasattr (self .model , 'aux_neck' ):
@@ -1285,8 +1290,7 @@ def export(self, output_dir='output_inference', for_fd=False):
1285
1290
1286
1291
if not os .path .exists (save_dir ):
1287
1292
os .makedirs (save_dir )
1288
-
1289
- static_model , pruned_input_spec = self ._get_infer_cfg_and_input_spec (
1293
+ static_model , pruned_input_spec , input_spec = self ._get_infer_cfg_and_input_spec (
1290
1294
save_dir , yaml_name = yaml_name , model = model )
1291
1295
1292
1296
# dy2st and save model
@@ -1299,8 +1303,8 @@ def export(self, output_dir='output_inference', for_fd=False):
1299
1303
static_model .forward .rollback ()
1300
1304
with paddle .pir_utils .OldIrGuard ():
1301
1305
save_path_no_pir = save_dir
1302
- static_model , pruned_input_spec , = self ._get_infer_cfg_and_input_spec (
1303
- save_dir , yaml_name = yaml_name , model = model )
1306
+ static_model , pruned_input_spec = self ._model_to_static (
1307
+ model , input_spec )
1304
1308
paddle .jit .save (static_model , os .path .join (save_path_no_pir , save_name ), input_spec = pruned_input_spec )
1305
1309
else :
1306
1310
save_path_pir = os .path .join (os .path .dirname (save_dir ), f"{ os .path .basename (save_dir )} _pir" )
0 commit comments