import yaml
import json
import copy
+import shutil
import paddle
import paddle.nn as nn
from paddle.jit import to_static

from collections import OrderedDict
+from packaging import version
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
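
Aside: the new packaging.version import is what lets the export code compare Paddle versions such as "3.0.0b2" and the "0.0.0" typically reported by develop builds; plain string comparison gets pre-releases wrong. A minimal illustration, not part of the patch:

    # Illustrative only: why versions are parsed instead of compared as raw strings.
    from packaging import version

    # A pre-release sorts before its final release when parsed properly...
    assert version.parse("3.0.0b2") < version.parse("3.0.0")
    # ...while lexicographic string comparison puts "3.0.0b2" after "3.0.0".
    assert not ("3.0.0b2" < "3.0.0")
    # Develop builds of Paddle typically report "0.0.0", which is why the export
    # code below also checks for that value explicitly.
    assert version.parse("0.0.0") < version.parse("3.0.0b2")
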
@@ -39,21 +41,23 @@ def setup_orderdict():
def dump_infer_config(config, path, logger):
    setup_orderdict()
    infer_cfg = OrderedDict()
+    if not os.path.exists(os.path.dirname(path)):
+        os.makedirs(os.path.dirname(path))
    if config["Global"].get("pdx_model_name", None):
        infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
    if config["Global"].get("uniform_output_enabled", None):
        arch_config = config["Architecture"]
        if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
            common_dynamic_shapes = {
-                "x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
+                "x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
            }
        elif arch_config["model_type"] == "det":
            common_dynamic_shapes = {
                "x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
            }
        elif arch_config["algorithm"] == "SLANet":
            common_dynamic_shapes = {
-                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
+                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
            }
        elif arch_config["algorithm"] == "LaTeXOCR":
            common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
    logger.info("Export inference config file to {}".format(os.path.join(path)))


-def export_single_model(
-    model, arch_config, save_path, logger, input_shape=None, quanter=None
-):
+def dynamic_to_static(model, arch_config, logger, input_shape=None):
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
@@ -262,9 +264,50 @@ def export_single_model(
    for layer in model.sublayers():
        if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
            layer.rep()
+    return model
+
+
+def export_single_model(
+    model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
+):
+
+    model = dynamic_to_static(model, arch_config, logger, input_shape)

    if quanter is None:
-        paddle.jit.save(model, save_path)
+        paddle_version = version.parse(paddle.__version__)
+        if paddle_version >= version.parse(
+            "3.0.0b2"
+        ) or paddle_version == version.parse("0.0.0"):
+            save_path = os.path.dirname(save_path)
+            for enable_pir in [True, False]:
+                if not enable_pir:
+                    save_path_no_pir = os.path.join(save_path, "no_pir", "inference")
+                    model.forward.rollback()
+                    with paddle.pir_utils.OldIrGuard():
+                        model = dynamic_to_static(
+                            model, arch_config, logger, input_shape
+                        )
+                        paddle.jit.save(model, save_path_no_pir)
+                        shutil.copy(
+                            yaml_path,
+                            os.path.join(
+                                os.path.join(save_path, "no_pir"),
+                                os.path.basename(yaml_path),
+                            ),
+                        )
+                else:
+                    save_path_pir = os.path.join(save_path, "pir", "inference")
+                    paddle.jit.save(model, save_path_pir)
+                    shutil.copy(
+                        yaml_path,
+                        os.path.join(
+                            os.path.join(save_path, "pir"), os.path.basename(yaml_path)
+                        ),
+                    )
+            if os.path.exists(yaml_path):
+                os.remove(yaml_path)
+        else:
+            paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
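
The hunk above is the heart of the change: on Paddle 3.0.0b2 or newer (or a develop build reporting 0.0.0), the model is saved twice, first with the default PIR program format under pir/, then, after rolling back to_static and re-tracing under the old-IR guard, under no_pir/, with the dumped YAML config copied next to each. A condensed sketch of that control flow; every name except the Paddle and stdlib APIs is hypothetical, and rebuild_static stands in for dynamic_to_static:

    # Illustrative sketch only, not the patch itself.
    import os
    import shutil
    import paddle
    from packaging import version


    def save_both_formats(model, save_prefix, yaml_path, rebuild_static):
        # save_prefix is the usual ".../inference" prefix handed to paddle.jit.save.
        pv = version.parse(paddle.__version__)
        if pv >= version.parse("3.0.0b2") or pv == version.parse("0.0.0"):
            save_dir = os.path.dirname(save_prefix)
            # PIR is the default program format on new Paddle, so a plain jit.save lands in pir/.
            pir_dir = os.path.join(save_dir, "pir")
            paddle.jit.save(model, os.path.join(pir_dir, "inference"))
            shutil.copy(yaml_path, os.path.join(pir_dir, os.path.basename(yaml_path)))
            # For the legacy program format, undo to_static and re-trace under the old-IR guard.
            model.forward.rollback()
            no_pir_dir = os.path.join(save_dir, "no_pir")
            with paddle.pir_utils.OldIrGuard():
                model = rebuild_static(model)  # plays the role of dynamic_to_static(...)
                paddle.jit.save(model, os.path.join(no_pir_dir, "inference"))
                shutil.copy(yaml_path, os.path.join(no_pir_dir, os.path.basename(yaml_path)))
        else:
            # Older Paddle: single save, exactly as before this patch.
            paddle.jit.save(model, save_prefix)
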
@@ -362,19 +405,22 @@ def export(config, base_model=None, save_path=None):
        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
    else:
        input_shape = None
-
+    dump_infer_config(config, yaml_path, logger)
    if arch_config["algorithm"] in [
        "Distillation",
    ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(
-                model.model_list[idx], archs[idx], sub_model_save_path, logger
+                model.model_list[idx],
+                archs[idx],
+                sub_model_save_path,
+                logger,
+                yaml_path,
            )
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
-            model, arch_config, save_path, logger, input_shape=input_shape
+            model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
        )
-    dump_infer_config(config, yaml_path, logger)
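
Net effect of the reordering in this last hunk: dump_infer_config now runs before export_single_model, so the YAML already exists when each save branch copies it next to the exported program, and export_single_model removes the temporary top-level copy once it is done. On a Paddle build that takes the dual-save path, the export directory ends up looking roughly like this (the artifact names are my assumption from how paddle.jit.save names its outputs under PIR and the legacy IR):

    <save_dir>/
        pir/
            inference.json        # PIR program
            inference.pdiparams
            <config>.yml
        no_pir/
            inference.pdmodel     # legacy program
            inference.pdiparams
            <config>.yml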