diff --git a/core/yolov3.py b/core/yolov3.py index 3ea09c264..322a9444e 100644 --- a/core/yolov3.py +++ b/core/yolov3.py @@ -96,8 +96,9 @@ def __build_nework(self, input_data): def decode(self, conv_output, anchors, stride): """ - return tensor of shape [batch_size, output_size, output_size, anchor_per_scale, 5 + num_classes] - contains (x, y, w, h, score, probability) + return tensor of shape [batch_size, output_size, + output_size, anchor_per_scale, 5 + num_classes] + contains (x, y, w, h, score, probability) """ conv_shape = tf.shape(conv_output) @@ -105,28 +106,39 @@ def decode(self, conv_output, anchors, stride): output_size = conv_shape[1] anchor_per_scale = len(anchors) - conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, anchor_per_scale, 5 + self.num_class)) + STEP = 5 + self.num_class + conv_sigmoid = tf.sigmoid(conv_output) + conv_sig_dx = conv_sigmoid[:, :, :, 0::STEP] + conv_sig_dy = conv_sigmoid[:, :, :, 1::STEP] + conv_raw_dw = conv_output[:, :, :, 2::STEP] + conv_raw_dh = conv_output[:, :, :, 3::STEP] - conv_raw_dxdy = conv_output[:, :, :, :, 0:2] - conv_raw_dwdh = conv_output[:, :, :, :, 2:4] - conv_raw_conf = conv_output[:, :, :, :, 4:5] - conv_raw_prob = conv_output[:, :, :, :, 5: ] - y = tf.tile(tf.range(output_size, dtype=tf.int32)[:, tf.newaxis], [1, output_size]) - x = tf.tile(tf.range(output_size, dtype=tf.int32)[tf.newaxis, :], [output_size, 1]) + x = tf.tile(tf.range(output_size, dtype=tf.int32)[ + tf.newaxis, tf.newaxis, :, tf.newaxis], + [batch_size, output_size, 1, anchor_per_scale]) + y = tf.tile(tf.range(output_size, dtype=tf.int32)[ + tf.newaxis, :, tf.newaxis, tf.newaxis], + [batch_size, 1, output_size, anchor_per_scale]) - xy_grid = tf.concat([x[:, :, tf.newaxis], y[:, :, tf.newaxis]], axis=-1) - xy_grid = tf.tile(xy_grid[tf.newaxis, :, :, tf.newaxis, :], [batch_size, 1, 1, anchor_per_scale, 1]) - xy_grid = tf.cast(xy_grid, tf.float32) + x = tf.cast(x, tf.float32) + y = tf.cast(y, tf.float32) - pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride - pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride - pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1) + pred_x = (conv_sig_dx + x) * stride + pred_y = (conv_sig_dy + y) * stride + pred_w = tf.exp(conv_raw_dw) * anchors[:, 0] * stride + pred_h = tf.exp(conv_raw_dh) * anchors[:, 1] * stride - pred_conf = tf.sigmoid(conv_raw_conf) - pred_prob = tf.sigmoid(conv_raw_prob) - return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1) + result = tf.concat([ + tf.concat([ + pred_x[:, :, :, i:i+1], pred_y[:, :, :, i:i+1], + pred_w[:, :, :, i:i+1], pred_h[:, :, :, i:i+1], + conv_sigmoid[:, :, :, 4 + STEP * i: STEP * (i + 1)]], axis=-1) + for i in range(anchor_per_scale)], axis=-1) + + return tf.reshape(result, + (batch_size, output_size, output_size, anchor_per_scale, STEP)) def focal(self, target, actual, alpha=1, gamma=2): focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma) diff --git a/freeze_graph.py b/freeze_graph.py index e4a17c791..65fbd244c 100644 --- a/freeze_graph.py +++ b/freeze_graph.py @@ -17,7 +17,7 @@ pb_file = "./yolov3_coco.pb" ckpt_file = "./checkpoint/yolov3_coco_demo.ckpt" -output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"] +output_node_names = ["input/input_data", "pred_sbbox/Reshape", "pred_mbbox/Reshape", "pred_lbbox/Reshape"] with tf.name_scope('input'): input_data = tf.placeholder(dtype=tf.float32, name='input_data') diff --git a/image_demo.py b/image_demo.py index c85356c01..aad914f8a 100644 --- a/image_demo.py +++ b/image_demo.py @@ -17,7 +17,7 @@ import tensorflow as tf from PIL import Image -return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"] +return_elements = ["input/input_data:0", "pred_sbbox/Reshape:0", "pred_mbbox/Reshape:0", "pred_lbbox/Reshape:0"] pb_file = "./yolov3_coco.pb" image_path = "./docs/images/road.jpeg" num_classes = 80 diff --git a/video_demo.py b/video_demo.py index 3bafc2d33..577283b2b 100644 --- a/video_demo.py +++ b/video_demo.py @@ -19,7 +19,7 @@ from PIL import Image -return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"] +return_elements = ["input/input_data:0", "pred_sbbox/Reshape:0", "pred_mbbox/Reshape:0", "pred_lbbox/Reshape:0"] pb_file = "./yolov3_coco.pb" video_path = "./docs/images/road.mp4" # video_path = 0