Skip to content

Commit d80dc58

Browse files
committed
add mot_pose_demo; sync with det benchmark codes
1 parent 2e61ae9 commit d80dc58

File tree

5 files changed

+47
-35
lines changed

5 files changed

+47
-35
lines changed

README_cn.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ PaddleDetection模块化地实现了多种主流目标检测算法,提供了
1717

1818
<div align="center">
1919
<img src="static/docs/images/football.gif" width='800'/>
20+
<img src="docs/images/mot_pose_demo_640x360.gif" width='800'/>
2021
</div>
2122

2223
### 产品动态

deploy/python/keypoint_det_unite_infer.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import os
16-
1716
from PIL import Image
1817
import cv2
1918
import numpy as np
@@ -52,7 +51,7 @@ def get_person_from_rect(images, results):
5251
org_rects = []
5352
for rect in valid_rects:
5453
rect_image, new_rect, org_rect = expand_crop(images, rect)
55-
if rect_image is None:
54+
if rect_image is None or rect_image.size == 0:
5655
continue
5756
image_buff.append([rect_image, new_rect])
5857
org_rects.append(org_rect)
@@ -113,13 +112,13 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
113112
os.makedirs(FLAGS.output_dir)
114113
out_path = os.path.join(FLAGS.output_dir, video_name)
115114
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
116-
index = 1
115+
index = 0
117116
while (1):
118117
ret, frame = capture.read()
119118
if not ret:
120119
break
121-
print('detect frame:%d' % (index))
122120
index += 1
121+
print('detect frame:%d' % (index))
123122

124123
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
125124
results = detector.predict(frame2, FLAGS.det_threshold)
@@ -136,7 +135,7 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
136135
keypoint_res = {}
137136
keypoint_res['keypoint'] = [
138137
np.vstack(keypoint_vector), np.vstack(score_vector)
139-
]
138+
] if len(keypoint_vector) > 0 else [[], []]
140139
keypoint_res['bbox'] = rect_vecotr
141140
im = draw_pose(
142141
frame,
@@ -189,8 +188,6 @@ def main():
189188
# predict from image
190189
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
191190
topdown_unite_predict(detector, topdown_keypoint_detector, img_list)
192-
detector.det_times.info(average=True)
193-
topdown_keypoint_detector.det_times.info(average=True)
194191

195192

196193
if __name__ == '__main__':

deploy/python/keypoint_infer.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from keypoint_visualize import draw_pose
2929
from paddle.inference import Config
3030
from paddle.inference import create_predictor
31-
from utils import argsparser, Timer, get_current_memory_mb, LoggerHelper
31+
from utils import argsparser, Timer, get_current_memory_mb
32+
from benchmark_utils import PaddleInferBenchmark
3233
from infer import get_test_images, print_arguments
3334

3435
# Global dictionary
@@ -66,7 +67,7 @@ def __init__(self,
6667
cpu_threads=1,
6768
enable_mkldnn=False):
6869
self.pred_config = pred_config
69-
self.predictor = load_predictor(
70+
self.predictor, self.config = load_predictor(
7071
model_dir,
7172
run_mode=run_mode,
7273
min_subgraph_size=self.pred_config.min_subgraph_size,
@@ -129,15 +130,15 @@ def predict(self, image, threshold=0.5, warmup=0, repeats=1):
129130
MaskRCNN's results include 'masks': np.ndarray:
130131
shape: [N, im_h, im_w]
131132
'''
132-
self.det_times.preprocess_time.start()
133+
self.det_times.preprocess_time_s.start()
133134
inputs = self.preprocess(image)
134135
np_boxes, np_masks = None, None
135136
input_names = self.predictor.get_input_names()
136137

137138
for i in range(len(input_names)):
138139
input_tensor = self.predictor.get_input_handle(input_names[i])
139140
input_tensor.copy_from_cpu(inputs[input_names[i]])
140-
self.det_times.preprocess_time.end()
141+
self.det_times.preprocess_time_s.end()
141142
for i in range(warmup):
142143
self.predictor.run()
143144
output_names = self.predictor.get_output_names()
@@ -152,7 +153,7 @@ def predict(self, image, threshold=0.5, warmup=0, repeats=1):
152153
inds_k.copy_to_cpu()
153154
]
154155

155-
self.det_times.inference_time.start()
156+
self.det_times.inference_time_s.start()
156157
for i in range(repeats):
157158
self.predictor.run()
158159
output_names = self.predictor.get_output_names()
@@ -166,12 +167,12 @@ def predict(self, image, threshold=0.5, warmup=0, repeats=1):
166167
masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
167168
inds_k.copy_to_cpu()
168169
]
169-
self.det_times.inference_time.end(repeats=repeats)
170+
self.det_times.inference_time_s.end(repeats=repeats)
170171

171-
self.det_times.postprocess_time.start()
172+
self.det_times.postprocess_time_s.start()
172173
results = self.postprocess(
173174
np_boxes, np_masks, inputs, threshold=threshold)
174-
self.det_times.postprocess_time.end()
175+
self.det_times.postprocess_time_s.end()
175176
self.det_times.img_num += 1
176177
return results
177178

@@ -318,7 +319,7 @@ def load_predictor(model_dir,
318319
# disable feed, fetch OP, needed by zero_copy_run
319320
config.switch_use_feed_fetch_ops(False)
320321
predictor = create_predictor(config)
321-
return predictor
322+
return predictor, config
322323

323324

324325
def predict_image(detector, image_list):
@@ -347,7 +348,8 @@ def predict_video(detector, camera_id):
347348
video_name = 'output.mp4'
348349
else:
349350
capture = cv2.VideoCapture(FLAGS.video_file)
350-
video_name = os.path.basename(os.path.split(FLAGS.video_file)[-1])
351+
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[
352+
0] + '.mp4'
351353
fps = 30
352354
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
353355
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -407,9 +409,22 @@ def main():
407409
'gpu_rss': detector.gpu_mem / len(img_list),
408410
'gpu_util': detector.gpu_util * 100 / len(img_list)
409411
}
410-
det_logger = LoggerHelper(
411-
FLAGS, detector.det_times.report(average=True), mems)
412-
det_logger.report()
412+
413+
perf_info = detector.det_times.report(average=True)
414+
model_dir = FLAGS.model_dir
415+
mode = FLAGS.run_mode
416+
model_info = {
417+
'model_name': model_dir.strip('/').split('/')[-1],
418+
'precision': mode.split('_')[-1]
419+
}
420+
data_info = {
421+
'batch_size': 1,
422+
'shape': "dynamic_shape",
423+
'data_num': perf_info['img_num']
424+
}
425+
det_log = PaddleInferBenchmark(detector.config, model_info,
426+
data_info, perf_info, mems)
427+
det_log('KeyPoint')
413428

414429

415430
if __name__ == '__main__':

deploy/python/keypoint_visualize.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
import math
2020

2121

22-
def map_coco_to_personlab(keypoints):
23-
permute = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3]
24-
return keypoints[:, permute, :]
25-
26-
2722
def draw_pose(imgfile,
2823
results,
2924
visual_thread=0.6,
@@ -39,9 +34,9 @@ def draw_pose(imgfile,
3934
'for example: `pip install matplotlib`.')
4035
raise e
4136

42-
EDGES = [(0, 14), (0, 13), (0, 4), (0, 1), (14, 16), (13, 15), (4, 10),
43-
(1, 7), (10, 11), (7, 8), (11, 12), (8, 9), (4, 5), (1, 2), (5, 6),
44-
(2, 3)]
37+
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
38+
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15),
39+
(14, 16), (11, 12)]
4540
NUM_EDGES = len(EDGES)
4641

4742
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
@@ -52,33 +47,35 @@ def draw_pose(imgfile,
5247

5348
img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
5449
skeletons, scores = results['keypoint']
50+
color_set = results['colors'] if 'colors' in results else None
5551

5652
if 'bbox' in results:
5753
bboxs = results['bbox']
58-
for idx, rect in enumerate(bboxs):
54+
for j, rect in enumerate(bboxs):
5955
xmin, ymin, xmax, ymax = rect
60-
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), colors[0], 1)
56+
color = colors[0] if color_set is None else colors[color_set[j] %
57+
len(colors)]
58+
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
6159

6260
canvas = img.copy()
6361
for i in range(17):
64-
rgba = np.array(cmap(1 - i / 17. - 1. / 34))
65-
rgba[0:3] *= 255
6662
for j in range(len(skeletons)):
6763
if skeletons[j][i, 2] < visual_thread:
6864
continue
65+
color = colors[i] if color_set is None else colors[color_set[j] %
66+
len(colors)]
6967
cv2.circle(
7068
canvas,
7169
tuple(skeletons[j][i, 0:2].astype('int32')),
7270
2,
73-
colors[i],
71+
color,
7472
thickness=-1)
7573

7674
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
7775
fig = matplotlib.pyplot.gcf()
7876

7977
stickwidth = 2
8078

81-
skeletons = map_coco_to_personlab(skeletons)
8279
for i in range(NUM_EDGES):
8380
for j in range(len(skeletons)):
8481
edge = EDGES[i]
@@ -96,7 +93,9 @@ def draw_pose(imgfile,
9693
polygon = cv2.ellipse2Poly((int(mY), int(mX)),
9794
(int(length / 2), stickwidth),
9895
int(angle), 0, 360, 1)
99-
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
96+
color = colors[i] if color_set is None else colors[color_set[j] %
97+
len(colors)]
98+
cv2.fillConvexPoly(cur_canvas, polygon, color)
10099
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
101100
if returnimg:
102101
return canvas

docs/images/mot_pose_demo_640x360.gif

22 MB
Loading

0 commit comments

Comments
 (0)