@@ -41,6 +41,12 @@ def init_args():
41
41
parser .add_argument ("--use_xpu" , type = str2bool , default = False )
42
42
parser .add_argument ("--use_npu" , type = str2bool , default = False )
43
43
parser .add_argument ("--use_mlu" , type = str2bool , default = False )
44
+ parser .add_argument (
45
+ "--use_gcu" ,
46
+ type = str2bool ,
47
+ default = False ,
48
+ help = "Use Enflame GCU(General Compute Unit)" ,
49
+ )
44
50
parser .add_argument ("--ir_optim" , type = str2bool , default = True )
45
51
parser .add_argument ("--use_tensorrt" , type = str2bool , default = False )
46
52
parser .add_argument ("--min_subgraph_size" , type = int , default = 15 )
@@ -298,6 +304,34 @@ def create_predictor(args, mode, logger):
298
304
config .enable_custom_device ("mlu" )
299
305
elif args .use_xpu :
300
306
config .enable_xpu (10 * 1024 * 1024 )
307
+ elif args .use_gcu : # for Enflame GCU(General Compute Unit)
308
+ assert paddle .device .is_compiled_with_custom_device ("gcu" ), (
309
+ "Args use_gcu cannot be set as True while your paddle "
310
+ "is not compiled with gcu! \n Please try: \n "
311
+ "\t 1. Install paddle-custom-gcu to run model on GCU. \n "
312
+ "\t 2. Set use_gcu as False in args to run model on CPU."
313
+ )
314
+ import paddle_custom_device .gcu .passes as gcu_passes
315
+
316
+ gcu_passes .setUp ()
317
+ if args .precision == "fp16" :
318
+ config .enable_custom_device (
319
+ "gcu" , 0 , paddle .inference .PrecisionType .Half
320
+ )
321
+ gcu_passes .set_exp_enable_mixed_precision_ops (config )
322
+ else :
323
+ config .enable_custom_device ("gcu" )
324
+
325
+ if paddle .framework .use_pir_api ():
326
+ config .enable_new_ir (True )
327
+ config .enable_new_executor (True )
328
+ kPirGcuPasses = gcu_passes .inference_passes (
329
+ use_pir = True , name = "PaddleOCR"
330
+ )
331
+ config .enable_custom_passes (kPirGcuPasses , True )
332
+ else :
333
+ pass_builder = config .pass_builder ()
334
+ gcu_passes .append_passes_for_legacy_ir (pass_builder , "PaddleOCR" )
301
335
else :
302
336
config .disable_gpu ()
303
337
if args .enable_mkldnn :
@@ -314,7 +348,8 @@ def create_predictor(args, mode, logger):
314
348
# enable memory optim
315
349
config .enable_memory_optim ()
316
350
config .disable_glog_info ()
317
- config .delete_pass ("conv_transpose_eltwiseadd_bn_fuse_pass" )
351
+ if not args .use_gcu : # for Enflame GCU(General Compute Unit)
352
+ config .delete_pass ("conv_transpose_eltwiseadd_bn_fuse_pass" )
318
353
config .delete_pass ("matmul_transpose_reshape_fuse_pass" )
319
354
if mode == "rec" and args .rec_algorithm == "SRN" :
320
355
config .delete_pass ("gpu_cpu_map_matmul_v2_to_matmul_pass" )
0 commit comments