1
+ import random
2
+ import argparse
3
+ import numpy as np
4
+
5
+ from preprocessing import parse_annotation
6
+ import json
7
+
8
+ argparser = argparse .ArgumentParser ()
9
+
10
+ argparser .add_argument (
11
+ '-c' ,
12
+ '--conf' ,
13
+ default = 'config.json' ,
14
+ help = 'path to configuration file' )
15
+
16
+ argparser .add_argument (
17
+ '-a' ,
18
+ '--anchors' ,
19
+ default = 5 ,
20
+ help = 'number of anchors to use' )
21
+
22
+ def IOU (ann , centroids ):
23
+ w , h = ann
24
+ similarities = []
25
+
26
+ for centroid in centroids :
27
+ c_w , c_h = centroid
28
+
29
+ if c_w >= w and c_h >= h :
30
+ similarity = w * h / (c_w * c_h )
31
+ elif c_w >= w and c_h <= h :
32
+ similarity = w * c_h / (w * h + (c_w - w )* c_h )
33
+ elif c_w <= w and c_h >= h :
34
+ similarity = c_w * h / (w * h + c_w * (c_h - h ))
35
+ else : #means both w,h are bigger than c_w and c_h respectively
36
+ similarity = (c_w * c_h )/ (w * h )
37
+ similarities .append (similarity ) # will become (k,) shape
38
+
39
+ return np .array (similarities )
40
+
41
+ def avg_IOU (anns , centroids ):
42
+ n ,d = anns .shape
43
+ sum = 0.
44
+
45
+ for i in range (anns .shape [0 ]):
46
+ sum += max (IOU (anns [i ], centroids ))
47
+
48
+ return sum / n
49
+
50
+ def print_anchors (centroids ):
51
+ anchors = centroids .copy ()
52
+
53
+ widths = anchors [:, 0 ]
54
+ sorted_indices = np .argsort (widths )
55
+
56
+ r = "anchors: ["
57
+ for i in sorted_indices [:- 1 ]:
58
+ r += '%0.2f,%0.2f, ' % (anchors [i ,0 ], anchors [i ,1 ])
59
+
60
+ #there should not be comma after last anchor, that's why
61
+ r += '%0.2f,%0.2f' % (anchors [sorted_indices [- 1 :],0 ], anchors [sorted_indices [- 1 :],1 ])
62
+ r += "]"
63
+
64
+ print r
65
+
66
+ def run_kmeans (ann_dims , anchor_num ):
67
+ ann_num = ann_dims .shape [0 ]
68
+ iterations = 0
69
+ prev_assignments = np .ones (ann_num )* (- 1 )
70
+ iteration = 0
71
+ old_distances = np .zeros ((ann_num , anchor_num ))
72
+
73
+ indices = [random .randrange (ann_dims .shape [0 ]) for i in range (anchor_num )]
74
+ centroids = ann_dims [indices ]
75
+ anchor_dim = ann_dims .shape [1 ]
76
+
77
+ while True :
78
+ distances = []
79
+ iteration += 1
80
+ for i in range (ann_num ):
81
+ d = 1 - IOU (ann_dims [i ], centroids )
82
+ distances .append (d )
83
+ distances = np .array (distances ) # distances.shape = (ann_num, anchor_num)
84
+
85
+ print "iteration {}: dists = {}" .format (iteration , np .sum (np .abs (old_distances - distances )))
86
+
87
+ #assign samples to centroids
88
+ assignments = np .argmin (distances ,axis = 1 )
89
+
90
+ if (assignments == prev_assignments ).all () :
91
+ return centroids
92
+
93
+ #calculate new centroids
94
+ centroid_sums = np .zeros ((anchor_num , anchor_dim ), np .float )
95
+ for i in range (ann_num ):
96
+ centroid_sums [assignments [i ]]+= ann_dims [i ]
97
+ for j in range (anchor_num ):
98
+ < << << << HEAD
99
+ centroids [j ] = centroid_sums [j ]/ (np .sum (assignments == j ) + 1e-6 )
100
+ == == == =
101
+ centroids [j ] = centroid_sums [j ]/ (np .sum (assignments == j ))
102
+ >> >> >> > 4843 db4a6c199941d8e4e57bdd86af7d33882c61
103
+
104
+ prev_assignments = assignments .copy ()
105
+ old_distances = distances .copy ()
106
+
107
+ def main (argv ):
108
+ config_path = args .conf
109
+ num_anchors = args .anchors
110
+
111
+ with open (config_path ) as config_buffer :
112
+ config = json .loads (config_buffer .read ())
113
+
114
+ train_imgs , train_labels = parse_annotation (config ['train' ]['train_annot_folder' ],
115
+ config ['train' ]['train_image_folder' ],
116
+ config ['model' ]['labels' ])
117
+
118
+ grid_w = config ['model' ]['input_size' ]/ 32
119
+ grid_h = config ['model' ]['input_size' ]/ 32
120
+
121
+ # run k_mean to find the anchors
122
+ annotation_dims = []
123
+ for image in train_imgs :
124
+ cell_w = image ['width' ]/ grid_w
125
+ cell_h = image ['height' ]/ grid_h
126
+
127
+ for obj in image ['object' ]:
128
+ relative_w = (float (obj ['xmax' ]) - float (obj ['xmin' ]))/ cell_w
129
+ relatice_h = (float (obj ["ymax" ]) - float (obj ['ymin' ]))/ cell_h
130
+ annotation_dims .append (map (float , (relative_w ,relatice_h )))
131
+ annotation_dims = np .array (annotation_dims )
132
+
133
+ centroids = run_kmeans (annotation_dims , num_anchors )
134
+
135
+ # write anchors to file
136
+ print '\n average IOU for' , num_anchors , 'anchors:' , '%0.2f' % avg_IOU (annotation_dims , centroids )
137
+ print_anchors (centroids )
138
+
139
+ if __name__ == '__main__' :
140
+ args = argparser .parse_args ()
141
+ main (args )
0 commit comments