@@ -118,7 +118,7 @@ def create_predictor(model_path,
118
118
elif device == "cpu" :
119
119
config .disable_gpu ()
120
120
elif device == "npu" :
121
- config .enable_npu ( )
121
+ config .enable_custom_device ( 'npu' )
122
122
elif device == "xpu" :
123
123
config .enable_xpu ()
124
124
else :
@@ -174,11 +174,10 @@ def main():
174
174
random .seed (args .seed )
175
175
np .random .seed (args .seed )
176
176
cfg = get_config (args .config_file , args .opt , show = True )
177
- predictor , config = create_predictor (args .model_path , args .device , args .run_mode ,
178
- args .batch_size , args .min_subgraph_size ,
179
- args .use_dynamic_shape , args .trt_min_shape ,
180
- args .trt_max_shape , args .trt_opt_shape ,
181
- args .trt_calib_mode )
177
+ predictor , config = create_predictor (
178
+ args .model_path , args .device , args .run_mode , args .batch_size ,
179
+ args .min_subgraph_size , args .use_dynamic_shape , args .trt_min_shape ,
180
+ args .trt_max_shape , args .trt_opt_shape , args .trt_calib_mode )
182
181
input_handles = [
183
182
predictor .get_input_handle (name )
184
183
for name in predictor .get_input_names ()
@@ -225,15 +224,15 @@ def main():
225
224
elif model_type == "cyclegan" :
226
225
import auto_log
227
226
logger = get_logger (name = 'ppgan' )
228
-
227
+
229
228
size = data ['A' ].shape
230
229
pid = os .getpid ()
231
230
auto_logger = auto_log .AutoLogger (
232
231
model_name = args .model_type ,
233
232
model_precision = args .run_mode ,
234
233
batch_size = args .batch_size ,
235
234
data_shape = size ,
236
- save_path = args .output_path + 'auto_log.lpg' ,
235
+ save_path = args .output_path + 'auto_log.lpg' ,
237
236
inference_config = config ,
238
237
pids = pid ,
239
238
process_name = None ,
@@ -254,7 +253,11 @@ def main():
254
253
save_image (
255
254
image_numpy ,
256
255
os .path .join (args .output_path , "cyclegan/{}.png" .format (i )))
257
- logger .info ("Inference succeeded! The inference result has been saved in {}" .format (os .path .join (args .output_path , "cyclegan/{}.png" .format (i ))))
256
+ logger .info (
257
+ "Inference succeeded! The inference result has been saved in {}"
258
+ .format (
259
+ os .path .join (args .output_path ,
260
+ "cyclegan/{}.png" .format (i ))))
258
261
auto_logger .times .end (stamp = True )
259
262
auto_logger .report ()
260
263
metric_file = os .path .join (args .output_path , "cyclegan/metric.txt" )
0 commit comments