Instantly share code, notes, and snippets.

# jinyu121/get_anchor.py

Last active March 5, 2024 02:36
Show Gist options
• Save jinyu121/e530dc9767d8f83c08f3582c71a5cbc8 to your computer and use it in GitHub Desktop.
YOLO2 Get Anchors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 # -*- coding: utf-8 -*-
"""Cluster YOLO label box sizes into anchor boxes.

Reads one or more list files (one image path per line), loads the matching
YOLO label files, and clusters the ``(w, h)`` of every labelled box with
either an IoU-based k-means (``original``) or scikit-learn's Euclidean
k-means (``sklearn`` / ``sklearn-mini``).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import random

import numpy as np

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is cosmetic; fall back to plain iteration without it.
    def tqdm(iterable, **kwargs):
        return iterable

try:
    import sklearn.cluster as cluster
except ImportError:
    # Only the "sklearn*" engines need scikit-learn; "original" is numpy-only.
    cluster = None


def _pairwise_iou(boxes, centroids):
    """Return the (n_boxes, n_centroids) IoU matrix of corner-aligned boxes.

    Both inputs are ``(w, h)`` pairs. Boxes are compared as if they shared
    one corner, so the intersection is ``min(w) * min(h)``.
    """
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 2)
    centroids = np.asarray(centroids, dtype=float).reshape(-1, 2)
    inter = (np.minimum(boxes[:, None, 0], centroids[None, :, 0]) *
             np.minimum(boxes[:, None, 1], centroids[None, :, 1]))
    box_area = boxes[:, 0] * boxes[:, 1]
    centroid_area = centroids[:, 0] * centroids[:, 1]
    union = box_area[:, None] + centroid_area[None, :] - inter
    return inter / union


def iou(x, centroids):
    """Return the IoU between one box and every centroid.

    Args:
        x: a single ``(w, h)`` pair.
        centroids: array-like of ``(w, h)`` pairs.

    Returns:
        1-D ``numpy`` array with one IoU value per centroid.
    """
    return _pairwise_iou(x, centroids)[0]


def avg_iou(x, centroids):
    """Return the mean, over all boxes, of the best IoU with any centroid.

    Args:
        x: array of shape ``(n, 2)`` holding ``(w, h)`` pairs.
        centroids: array-like of ``(w, h)`` pairs.

    Raises:
        ValueError: if ``x`` contains no boxes.
    """
    overlaps = _pairwise_iou(x, centroids)
    if overlaps.shape[0] == 0:
        raise ValueError("avg_iou needs at least one box")
    return float(overlaps.max(axis=1).mean())


def write_anchors_to_file(centroids, distance, anchor_file,
                          input_size=416, stride=32):
    """Print the clustering result and write the anchors to ``anchor_file``.

    The centroids are scaled by ``input_size / stride``. The defaults
    (416 / 32 = 13) correspond to the YOLOv2 setup of a 416-pixel input
    downsampled by 32, i.e. sizes expressed in grid cells -- this assumes
    the label widths/heights are normalized to [0, 1].

    Args:
        centroids: array of ``(w, h)`` cluster centres.
        distance: average IoU reported next to the anchors.
        anchor_file: output path; its parent directory is created if missing.
        input_size: network input size used for scaling.
        stride: network downsampling factor used for scaling.
    """
    centroids = np.asarray(centroids)
    anchors = centroids * input_size / stride
    anchors = [str(i) for i in anchors.ravel()]
    print(
        "\n",
        "Cluster Result:\n",
        "Clusters:", len(centroids), "\n",
        "Average IoU:", distance, "\n",
        "Anchors:\n",
        ", ".join(anchors)
    )

    # The default output lives in ../results/, which may not exist yet.
    out_dir = os.path.dirname(anchor_file)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    with open(anchor_file, 'w') as f:
        f.write(", ".join(anchors))
        f.write('\n%f\n' % distance)


def k_means(x, n_clusters, eps):
    """Cluster ``(w, h)`` pairs with k-means using ``1 - IoU`` as distance.

    Args:
        x: array of shape ``(n, 2)`` holding ``(w, h)`` pairs.
        n_clusters: number of centroids to fit.
        eps: stop once the summed absolute change of the distance matrix
            between two iterations drops below this value.

    Returns:
        Array of shape ``(n_clusters, 2)`` with the fitted centroids.

    Raises:
        ValueError: if ``x`` is empty (raised by ``random.randrange``).
    """
    # Work in float so integer input does not truncate centroid updates.
    x = np.asarray(x, dtype=float)
    n_samples = x.shape[0]

    # Prefer distinct starting points: duplicated centroids leave one of the
    # duplicates without members on the very first assignment.
    if 0 < n_clusters <= n_samples:
        init_index = random.sample(range(n_samples), n_clusters)
    else:
        init_index = [random.randrange(n_samples) for _ in range(n_clusters)]
    centroids = x[init_index]

    old_d = None
    iterations = 0
    while True:
        iterations += 1
        d = 1 - _pairwise_iou(x, centroids)
        diff = 1e10 if old_d is None else np.sum(np.abs(d - old_d))
        print('diff = %f' % diff)

        # The iteration cap guards against oscillating assignments.
        if diff < eps or iterations > 1000:
            print("Number of iterations took = %d" % iterations)
            print("Centroids = ", centroids)
            return centroids

        # assign samples to centroids
        belonging_centroids = np.argmin(d, axis=1)

        # calculate the new centroids
        for j in range(n_clusters):
            members = x[belonging_centroids == j]
            # An empty cluster keeps its previous centroid instead of
            # becoming NaN through a 0 / 0 division.
            if members.shape[0] > 0:
                centroids[j] = members.mean(axis=0)

        old_d = d


def get_file_content(fnm):
    """Return the lines of ``fnm`` with surrounding whitespace stripped."""
    with open(fnm) as f:
        return [line.strip() for line in f]


def main(args):
    """Read all label files, cluster the box sizes and write the anchors.

    Args:
        args: namespace with ``file_list``, ``num_clusters``, ``output``,
            ``tol`` and ``engine`` attributes (see the argument parser).

    Raises:
        ValueError: if no boxes are found or the engine is unknown.
        ImportError: if a sklearn engine is requested without scikit-learn.
    """
    print("Reading Data ...")

    file_list = []
    for f in args.file_list:
        file_list.extend(get_file_content(f))

    data = []
    for one_file in tqdm(file_list):
        if not one_file:
            continue  # blank line in the list file
        # Map an image path to its label path.
        one_file = one_file.replace('images', 'labels') \
            .replace('JPEGImages', 'labels') \
            .replace('.png', '.txt') \
            .replace('.jpg', '.txt')
        for line in get_file_content(one_file):
            if not line:
                continue  # tolerate blank / trailing empty lines
            clazz, xx, yy, w, h = line.split()
            data.append([float(w), float(h)])

    if not data:
        raise ValueError("No bounding boxes found in the given label files")
    data = np.array(data)

    if args.engine.startswith("sklearn"):
        if cluster is None:
            raise ImportError(
                "scikit-learn is required for engine %r" % args.engine)
        if args.engine == "sklearn":
            km = cluster.KMeans(n_clusters=args.num_clusters,
                                tol=args.tol, verbose=True)
        elif args.engine == "sklearn-mini":
            km = cluster.MiniBatchKMeans(n_clusters=args.num_clusters,
                                         tol=args.tol, verbose=True)
        else:
            raise ValueError("Unknown engine %r" % args.engine)
        km.fit(data)
        result = km.cluster_centers_
    else:
        result = k_means(data, args.num_clusters, args.tol)

    distance = avg_iou(data, result)
    write_anchors_to_file(result, distance, args.output)


if "__main__" == __name__:
    parser = argparse.ArgumentParser()
    parser.add_argument('file_list', nargs='+', help='TrainList')
    parser.add_argument('--num_clusters', '-n', default=5, type=int,
                        help='Number of Clusters')
    parser.add_argument('--output', '-o', default='../results/anchor.txt',
                        type=str, help='Result Output File')
    parser.add_argument('--tol', '-t', default=0.005, type=float,
                        help='Tolerate')
    parser.add_argument('--engine', '-m', default='sklearn', type=str,
                        choices=['original', 'sklearn', 'sklearn-mini'],
                        help='Method to use')

    args = parser.parse_args()

    main(args)

### jinyu121 commented Oct 17, 2017

anchors差别可能还和数据有关，我在自己的标注图片上跑的，用同一种算法得到的结果差别不大（和随机初始值有关，但是结果也就差0.1不到），但是不同算法得到的就不一样了

engine result
sklearn 11.22762995, 10.73226759, 10.68771405, 9.1692398, 7.452993003, 6.555998014, 6.299477413, 4.831884219, 3.57714225, 3.068625678
original 10.83653525, 10.64363488, 10.54793825, 7.023316928, 6.781054383, 6.024040211, 6.005671164, 4.968083672, 3.417490904, 1.448678908

### PythonImageDeveloper commented Mar 12, 2018

Hi ,
This file generates 10 anchor values, and I have a question about them. Since we have 5 anchors and the generator produces 10 values, presumably the first two of the 10 values belong to the first anchor box, right? If so, what do these two values mean? Are they the W and H of the first anchor, or its aspect ratio and scale?

### jinyu121 commented Mar 28, 2018

@zeynali

The 10 values can be grouped as 5 pairs. For example, `11.22762995, 10.73226759, 10.68771405, 9.1692398, 7.452993003, 6.555998014, 6.299477413, 4.831884219, 3.57714225, 3.068625678` means `(11.22762995, 10.73226759), (10.68771405, 9.1692398), (7.452993003, 6.555998014), (6.299477413, 4.831884219), (3.57714225, 3.068625678)`

In my view, the values are the W and H of each anchor at some scale (the script stores each box as `[w, h]`), so we can just multiply or add them to the output of the net.

### muzi1012 commented Apr 17, 2018

@muzi1012

The txt file can be generated by this file. Each file contains multiple lines, and each line is the full path of one image.

For example, if you have the training list(s) like this:

``````001
002
003
``````

and

``````101
102
103
``````

After the processing by `voc_label.py`, you may get files like

train_part_1.txt

``````path_to_voc/VOC2007/JPEGImages/001.png
path_to_voc/VOC2007/JPEGImages/002.png
path_to_voc/VOC2007/JPEGImages/003.png
....
``````

train_part_2.txt

``````path_to_voc/VOC2007/JPEGImages/101.png
path_to_voc/VOC2007/JPEGImages/102.png
path_to_voc/VOC2007/JPEGImages/103.png
....
``````

Then, you can use `python ./get_anchor.py train_part_1.txt train_part_2.txt` to get anchors.