Created
November 13, 2017 11:16
-
-
Save Deepayan137/d7b93c783f531a13e54ae1fb19421f06 to your computer and use it in GitHub Desktop.
Implementation of K nearest neighbours on image features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from doctools.cluster.mst import cluster | |
from doctools.cluster.distance import jaccard, lev, euc | |
from doctools.parser.convert import page_to_unit | |
from doctools.parser import webtotrain | |
from argparse import ArgumentParser | |
from pprint import pprint | |
from .dot import as_dot | |
import json | |
from functools import partial | |
import pdb | |
from doctools.ocr import GravesOCR | |
import os | |
from .opts import base_opts | |
import numpy as np | |
import matplotlib | |
matplotlib.use('agg') | |
from matplotlib import pyplot as plt | |
import cv2 | |
import pdb | |
from sklearn.neighbors import KNeighborsClassifier | |
from collections import Counter | |
def get_index(a):
    """Group positions of equal values in `a`.

    Returns a dict mapping each distinct value to the sorted list of indices
    at which it occurs, with entries ordered from most frequent value to
    least frequent (ties keep the sorted-by-value order).
    """
    # Initialise keys in sorted order so that equal-frequency groups keep a
    # deterministic (value-sorted) relative order after the stable sort below.
    positions = {key: [] for key in sorted(set(a))}
    for pos, value in enumerate(a):
        positions[value].append(pos)
    ranked = sorted(positions.items(), key=lambda item: -len(item[1]))
    return dict(ranked)
def visualize(path, components, output_dir=None, max_subplots=10):
    """Save one figure per cluster showing up to `max_subplots` member images.

    Parameters:
        path: directory holding the word images; indices in `components`
            index into os.listdir(path).
            NOTE(review): os.listdir order is arbitrary — this assumes it
            matches the order used when features were extracted; confirm
            upstream.
        components: iterable of clusters, each a list of image indices.
        output_dir: directory for the cluster_<i>.png outputs; defaults to
            the script-level `args.output` for backward compatibility.
        max_subplots: cap on the number of images shown per cluster.
    """
    if output_dir is None:
        output_dir = args.output  # fall back to module-level CLI args
    images = os.listdir(path)
    for i, members in enumerate(components):
        print("cluster no: %d" % i)
        # Bug fix: the original reassigned the shared cap with min(...), so a
        # small cluster permanently shrank the subplot count for all later
        # clusters. Use a per-cluster local instead.
        n = min(len(members), max_subplots)
        # Bug fix: clear the figure so stale subplots from a larger previous
        # cluster do not leak into this cluster's saved image.
        plt.clf()
        for j in range(n):
            img_path = os.path.join(path, images[members[j]])
            print(img_path)
            im = cv2.imread(img_path)
            print("subploting [%d/%d]" % (j, len(members)))
            print(1, n, j + 1)
            plt.subplot(1, n, j + 1), plt.xticks([]), plt.yticks([])
            plt.imshow(im)
        # Bug fix: save once per cluster, not once per subplot.
        plt.savefig(os.path.join(output_dir, 'cluster_%d.png' % i))
if __name__ == '__main__':
    parser = ArgumentParser()
    base_opts(parser)
    args = parser.parse_args()

    # Load the experiment configuration (fix: the file handle was never
    # closed — use a context manager).
    with open(args.config) as config_file:
        config = json.load(config_file)

    # Load OCR
    print(config["model"])
    ocr = GravesOCR(config["model"], config["lookup"])

    # Locate the per-book feature matrix and its annotation file.
    book_name = config["books"][args.book]
    fpath = os.path.join(config["feat_dir"], book_name)
    feat = np.load(os.path.join(fpath, "feats.npy"))
    with open(os.path.join(fpath, 'annotation.txt'), 'r') as in_file:
        lines = in_file.readlines()
    # Ground truth for each image: second whitespace-separated token per line.
    words = [line.rsplit()[1] for line in lines]
    features = [feat[i] for i in range(feat.shape[0])]

    # Train/test split: first TRAIN_SIZE samples train, the rest test.
    TRAIN_SIZE = 2000
    test_words = np.array(words[TRAIN_SIZE:])
    test_features = np.array(features[TRAIN_SIZE:])

    # Keep only test words that occur at least 3 times, so every query has
    # enough same-class instances for a 3-NN vote to be meaningful.
    test_index = get_index(test_words)
    indices = sum([v for i, v in test_index.items() if len(v) >= 3], [])
    test_words = test_words[indices]
    test_features = test_features[indices]

    # 3-nearest-neighbour classifier on the raw image features.
    neigh = KNeighborsClassifier(n_neighbors=3)
    neigh.fit(features[:TRAIN_SIZE], words[:TRAIN_SIZE])
    y_predict = neigh.predict(test_features)
    # Fix: guard against division by zero when no test word has >= 3
    # instances; vectorized mean replaces the manual 0/1 list.
    accuracy = float(np.mean(y_predict == test_words)) if len(test_words) else 0.0
    print(accuracy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment