Skip to content

Commit a99e506

Browse files
authored
update npu inference api (#779)
1 parent fb1b085 commit a99e506

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tools/inference.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create_predictor(model_path,
118118
elif device == "cpu":
119119
config.disable_gpu()
120120
elif device == "npu":
121-
config.enable_npu()
121+
config.enable_custom_device('npu')
122122
elif device == "xpu":
123123
config.enable_xpu()
124124
else:
@@ -174,11 +174,10 @@ def main():
174174
random.seed(args.seed)
175175
np.random.seed(args.seed)
176176
cfg = get_config(args.config_file, args.opt, show=True)
177-
predictor, config = create_predictor(args.model_path, args.device, args.run_mode,
178-
args.batch_size, args.min_subgraph_size,
179-
args.use_dynamic_shape, args.trt_min_shape,
180-
args.trt_max_shape, args.trt_opt_shape,
181-
args.trt_calib_mode)
177+
predictor, config = create_predictor(
178+
args.model_path, args.device, args.run_mode, args.batch_size,
179+
args.min_subgraph_size, args.use_dynamic_shape, args.trt_min_shape,
180+
args.trt_max_shape, args.trt_opt_shape, args.trt_calib_mode)
182181
input_handles = [
183182
predictor.get_input_handle(name)
184183
for name in predictor.get_input_names()
@@ -225,15 +224,15 @@ def main():
225224
elif model_type == "cyclegan":
226225
import auto_log
227226
logger = get_logger(name='ppgan')
228-
227+
229228
size = data['A'].shape
230229
pid = os.getpid()
231230
auto_logger = auto_log.AutoLogger(
232231
model_name=args.model_type,
233232
model_precision=args.run_mode,
234233
batch_size=args.batch_size,
235234
data_shape=size,
236-
save_path=args.output_path+'auto_log.lpg',
235+
save_path=args.output_path + 'auto_log.lpg',
237236
inference_config=config,
238237
pids=pid,
239238
process_name=None,
@@ -254,7 +253,11 @@ def main():
254253
save_image(
255254
image_numpy,
256255
os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
257-
logger.info("Inference succeeded! The inference result has been saved in {}".format(os.path.join(args.output_path, "cyclegan/{}.png".format(i))))
256+
logger.info(
257+
"Inference succeeded! The inference result has been saved in {}"
258+
.format(
259+
os.path.join(args.output_path,
260+
"cyclegan/{}.png".format(i))))
258261
auto_logger.times.end(stamp=True)
259262
auto_logger.report()
260263
metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")

0 commit comments

Comments
 (0)