Skip to content

Commit a3e3e92

Browse files
committed
updated reamde
1 parent 70cf2d7 commit a3e3e92

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed

gen_anchors.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import random
2+
import argparse
3+
import numpy as np
4+
5+
from preprocessing import parse_annotation
6+
import json
7+
8+
argparser = argparse.ArgumentParser()
9+
10+
argparser.add_argument(
11+
'-c',
12+
'--conf',
13+
default='config.json',
14+
help='path to configuration file')
15+
16+
argparser.add_argument(
17+
'-a',
18+
'--anchors',
19+
default=5,
20+
help='number of anchors to use')
21+
22+
def IOU(ann, centroids):
23+
w, h = ann
24+
similarities = []
25+
26+
for centroid in centroids:
27+
c_w, c_h = centroid
28+
29+
if c_w >= w and c_h >= h:
30+
similarity = w*h/(c_w*c_h)
31+
elif c_w >= w and c_h <= h:
32+
similarity = w*c_h/(w*h + (c_w-w)*c_h)
33+
elif c_w <= w and c_h >= h:
34+
similarity = c_w*h/(w*h + c_w*(c_h-h))
35+
else: #means both w,h are bigger than c_w and c_h respectively
36+
similarity = (c_w*c_h)/(w*h)
37+
similarities.append(similarity) # will become (k,) shape
38+
39+
return np.array(similarities)
40+
41+
def avg_IOU(anns, centroids):
42+
n,d = anns.shape
43+
sum = 0.
44+
45+
for i in range(anns.shape[0]):
46+
sum+= max(IOU(anns[i], centroids))
47+
48+
return sum/n
49+
50+
def print_anchors(centroids):
51+
anchors = centroids.copy()
52+
53+
widths = anchors[:, 0]
54+
sorted_indices = np.argsort(widths)
55+
56+
r = "anchors: ["
57+
for i in sorted_indices[:-1]:
58+
r += '%0.2f,%0.2f, ' % (anchors[i,0], anchors[i,1])
59+
60+
#there should not be comma after last anchor, that's why
61+
r += '%0.2f,%0.2f' % (anchors[sorted_indices[-1:],0], anchors[sorted_indices[-1:],1])
62+
r += "]"
63+
64+
print r
65+
66+
def run_kmeans(ann_dims, anchor_num):
67+
ann_num = ann_dims.shape[0]
68+
iterations = 0
69+
prev_assignments = np.ones(ann_num)*(-1)
70+
iteration = 0
71+
old_distances = np.zeros((ann_num, anchor_num))
72+
73+
indices = [random.randrange(ann_dims.shape[0]) for i in range(anchor_num)]
74+
centroids = ann_dims[indices]
75+
anchor_dim = ann_dims.shape[1]
76+
77+
while True:
78+
distances = []
79+
iteration += 1
80+
for i in range(ann_num):
81+
d = 1 - IOU(ann_dims[i], centroids)
82+
distances.append(d)
83+
distances = np.array(distances) # distances.shape = (ann_num, anchor_num)
84+
85+
print "iteration {}: dists = {}".format(iteration, np.sum(np.abs(old_distances-distances)))
86+
87+
#assign samples to centroids
88+
assignments = np.argmin(distances,axis=1)
89+
90+
if (assignments == prev_assignments).all() :
91+
return centroids
92+
93+
#calculate new centroids
94+
centroid_sums=np.zeros((anchor_num, anchor_dim), np.float)
95+
for i in range(ann_num):
96+
centroid_sums[assignments[i]]+=ann_dims[i]
97+
for j in range(anchor_num):
98+
<<<<<<< HEAD
99+
centroids[j] = centroid_sums[j]/(np.sum(assignments==j) + 1e-6)
100+
=======
101+
centroids[j] = centroid_sums[j]/(np.sum(assignments==j))
102+
>>>>>>> 4843db4a6c199941d8e4e57bdd86af7d33882c61
103+
104+
prev_assignments = assignments.copy()
105+
old_distances = distances.copy()
106+
107+
def main(argv):
108+
config_path = args.conf
109+
num_anchors = args.anchors
110+
111+
with open(config_path) as config_buffer:
112+
config = json.loads(config_buffer.read())
113+
114+
train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'],
115+
config['train']['train_image_folder'],
116+
config['model']['labels'])
117+
118+
grid_w = config['model']['input_size']/32
119+
grid_h = config['model']['input_size']/32
120+
121+
# run k_mean to find the anchors
122+
annotation_dims = []
123+
for image in train_imgs:
124+
cell_w = image['width']/grid_w
125+
cell_h = image['height']/grid_h
126+
127+
for obj in image['object']:
128+
relative_w = (float(obj['xmax']) - float(obj['xmin']))/cell_w
129+
relatice_h = (float(obj["ymax"]) - float(obj['ymin']))/cell_h
130+
annotation_dims.append(map(float, (relative_w,relatice_h)))
131+
annotation_dims = np.array(annotation_dims)
132+
133+
centroids = run_kmeans(annotation_dims, num_anchors)
134+
135+
# write anchors to file
136+
print '\naverage IOU for', num_anchors, 'anchors:', '%0.2f' % avg_IOU(annotation_dims, centroids)
137+
print_anchors(centroids)
138+
139+
if __name__ == '__main__':
140+
args = argparser.parse_args()
141+
main(args)

0 commit comments

Comments
 (0)