Skip to content

Commit bd586f4

Browse files
Fix batch predict of cls and rec (#1089)
* fixbug_bs=1 of predict_cls\rec
1 parent 01ea6f9 commit bd586f4

File tree

2 files changed

+64
-18
lines changed

2 files changed

+64
-18
lines changed

deploy/python/predict_cls.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from python.preprocess import create_operators
2828
from python.postprocess import build_postprocess
2929

30-
3130
class ClsPredictor(Predictor):
3231
def __init__(self, config):
3332
super().__init__(config["Global"])
@@ -59,21 +58,47 @@ def predict(self, images):
5958
input_tensor.copy_from_cpu(image)
6059
self.paddle_predictor.run()
6160
batch_output = output_tensor.copy_to_cpu()
61+
if self.postprocess is not None:
62+
batch_output = self.postprocess(batch_output)
6263
return batch_output
6364

6465

6566
def main(config):
6667
cls_predictor = ClsPredictor(config)
6768
image_list = get_image_list(config["Global"]["infer_imgs"])
6869

69-
assert config["Global"]["batch_size"] == 1
70-
for idx, image_file in enumerate(image_list):
71-
img = cv2.imread(image_file)[:, :, ::-1]
72-
output = cls_predictor.predict(img)
73-
output = cls_predictor.postprocess(output, [image_file])
74-
print(output)
75-
return
70+
batch_imgs = []
71+
batch_names = []
72+
cnt = 0
73+
for idx, img_path in enumerate(image_list):
74+
img = cv2.imread(img_path)
75+
if img is None:
76+
logger.warning(
77+
"Image file failed to read and has been skipped. The path: {}".
78+
format(img_path))
79+
else:
80+
img = img[:, :, ::-1]
81+
batch_imgs.append(img)
82+
img_name = os.path.basename(img_path)
83+
batch_names.append(img_name)
84+
cnt += 1
7685

86+
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
87+
if len(batch_imgs) == 0:
88+
continue
89+
90+
batch_results = cls_predictor.predict(batch_imgs)
91+
for number, result_dict in enumerate(batch_results):
92+
filename = batch_names[number]
93+
clas_ids = result_dict["class_ids"]
94+
scores_str = "[{}]".format(", ".join("{:.2f}".format(
95+
r) for r in result_dict["scores"]))
96+
label_names = result_dict["label_names"]
97+
print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
98+
format(filename, clas_ids, scores_str, label_names))
99+
batch_imgs = []
100+
batch_names = []
101+
return
77102

78103
if __name__ == "__main__":
79104
args = config.parse_args()

deploy/python/predict_rec.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,48 @@ def predict(self, images, feature_normalize=True):
5454
input_tensor.copy_from_cpu(image)
5555
self.paddle_predictor.run()
5656
batch_output = output_tensor.copy_to_cpu()
57-
57+
5858
if feature_normalize:
5959
feas_norm = np.sqrt(
6060
np.sum(np.square(batch_output), axis=1, keepdims=True))
6161
batch_output = np.divide(batch_output, feas_norm)
62-
62+
63+
if self.postprocess is not None:
64+
batch_output = self.postprocess(batch_output)
6365
return batch_output
6466

6567

6668
def main(config):
6769
rec_predictor = RecPredictor(config)
6870
image_list = get_image_list(config["Global"]["infer_imgs"])
6971

70-
assert config["Global"]["batch_size"] == 1
71-
for idx, image_file in enumerate(image_list):
72-
batch_input = []
73-
img = cv2.imread(image_file)[:, :, ::-1]
74-
output = rec_predictor.predict(img)
75-
if rec_predictor.postprocess is not None:
76-
output = rec_predictor.postprocess(output)
77-
print(output)
72+
batch_imgs = []
73+
batch_names = []
74+
cnt = 0
75+
for idx, img_path in enumerate(image_list):
76+
img = cv2.imread(img_path)
77+
if img is None:
78+
logger.warning(
79+
"Image file failed to read and has been skipped. The path: {}".
80+
format(img_path))
81+
else:
82+
img = img[:, :, ::-1]
83+
batch_imgs.append(img)
84+
img_name = os.path.basename(img_path)
85+
batch_names.append(img_name)
86+
cnt += 1
87+
88+
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
89+
if len(batch_imgs) == 0:
90+
continue
91+
92+
batch_results = rec_predictor.predict(batch_imgs)
93+
for number, result_dict in enumerate(batch_results):
94+
filename = batch_names[number]
95+
print("{}:\t{}".format(filename, result_dict))
96+
batch_imgs = []
97+
batch_names = []
98+
7899
return
79100

80101

0 commit comments

Comments
 (0)