Skip to content

Commit c8874d7

Browse files
authored
[GCU] Support inference for GCU (#14142)
1 parent fbba217 commit c8874d7

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

tools/infer/utility.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def init_args():
4141
parser.add_argument("--use_xpu", type=str2bool, default=False)
4242
parser.add_argument("--use_npu", type=str2bool, default=False)
4343
parser.add_argument("--use_mlu", type=str2bool, default=False)
44+
parser.add_argument(
45+
"--use_gcu",
46+
type=str2bool,
47+
default=False,
48+
help="Use Enflame GCU(General Compute Unit)",
49+
)
4450
parser.add_argument("--ir_optim", type=str2bool, default=True)
4551
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
4652
parser.add_argument("--min_subgraph_size", type=int, default=15)
@@ -298,6 +304,34 @@ def create_predictor(args, mode, logger):
298304
config.enable_custom_device("mlu")
299305
elif args.use_xpu:
300306
config.enable_xpu(10 * 1024 * 1024)
307+
elif args.use_gcu: # for Enflame GCU(General Compute Unit)
308+
assert paddle.device.is_compiled_with_custom_device("gcu"), (
309+
"Args use_gcu cannot be set as True while your paddle "
310+
"is not compiled with gcu! \nPlease try: \n"
311+
"\t1. Install paddle-custom-gcu to run model on GCU. \n"
312+
"\t2. Set use_gcu as False in args to run model on CPU."
313+
)
314+
import paddle_custom_device.gcu.passes as gcu_passes
315+
316+
gcu_passes.setUp()
317+
if args.precision == "fp16":
318+
config.enable_custom_device(
319+
"gcu", 0, paddle.inference.PrecisionType.Half
320+
)
321+
gcu_passes.set_exp_enable_mixed_precision_ops(config)
322+
else:
323+
config.enable_custom_device("gcu")
324+
325+
if paddle.framework.use_pir_api():
326+
config.enable_new_ir(True)
327+
config.enable_new_executor(True)
328+
kPirGcuPasses = gcu_passes.inference_passes(
329+
use_pir=True, name="PaddleOCR"
330+
)
331+
config.enable_custom_passes(kPirGcuPasses, True)
332+
else:
333+
pass_builder = config.pass_builder()
334+
gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
301335
else:
302336
config.disable_gpu()
303337
if args.enable_mkldnn:
@@ -314,7 +348,8 @@ def create_predictor(args, mode, logger):
314348
# enable memory optim
315349
config.enable_memory_optim()
316350
config.disable_glog_info()
317-
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
351+
if not args.use_gcu: # for Enflame GCU(General Compute Unit)
352+
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
318353
config.delete_pass("matmul_transpose_reshape_fuse_pass")
319354
if mode == "rec" and args.rec_algorithm == "SRN":
320355
config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")

tools/program.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def merge_config(config, opts):
115115
return config
116116

117117

118-
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
118+
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False):
119119
"""
120120
Log an error and exit when use_gpu=true is set in the paddlepaddle
121121
cpu version.
@@ -154,6 +154,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
154154
if use_mlu and not paddle.device.is_compiled_with_mlu():
155155
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
156156
sys.exit(1)
157+
if use_gcu and not paddle.device.is_compiled_with_custom_device("gcu"):
158+
print(err.format("use_gcu", "gcu", "gcu", "use_gcu"))
159+
sys.exit(1)
157160
except Exception as e:
158161
pass
159162

@@ -799,6 +802,7 @@ def preprocess(is_train=False):
799802
use_xpu = config["Global"].get("use_xpu", False)
800803
use_npu = config["Global"].get("use_npu", False)
801804
use_mlu = config["Global"].get("use_mlu", False)
805+
use_gcu = config["Global"].get("use_gcu", False)
802806

803807
alg = config["Architecture"]["algorithm"]
804808
assert alg in [
@@ -853,9 +857,11 @@ def preprocess(is_train=False):
853857
device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
854858
elif use_mlu:
855859
device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
860+
elif use_gcu: # Use Enflame GCU(General Compute Unit)
861+
device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
856862
else:
857863
device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
858-
check_device(use_gpu, use_xpu, use_npu, use_mlu)
864+
check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu)
859865

860866
device = paddle.set_device(device)
861867

0 commit comments

Comments
 (0)