"""Implementation of k-nearest neighbours (KNN) on image features.

Gist by @Deepayan137, created November 13, 2017.
"""
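# The command-line interface comes from `base_opts` (defined elsewhere in
# doctools and not shown here). Assuming it registers --config, --book, and
# --output flags (an assumption, inferred only from the `args` attributes
# used below), a hypothetical invocation would look like:
#
#     python knn_image_features.py --config config.json --book 0 --output out/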
import json
import os
from argparse import ArgumentParser

import cv2
import matplotlib
matplotlib.use('agg')  # non-interactive backend, so figures can be saved without a display
from matplotlib import pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

from doctools.ocr import GravesOCR
from .opts import base_opts

def get_index(a):
    """Map each unique value in `a` to the list of positions where it occurs,
    with entries ordered by descending frequency."""
    uniques = np.unique(np.array(a))
    idx = {u: [] for u in uniques}
    for i, v in enumerate(a):
        idx[v].append(i)
    return dict(sorted(idx.items(), key=lambda x: -len(x[1])))
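
# For example, get_index(["to", "the", "to"]) returns {"to": [0, 2], "the": [1]}:
# "to" sorts first because it occurs twice, "the" only once.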
def visualize(path, components, output):
    """Save a row of up to 10 member images for each cluster under `output`."""
    max_subplots = 10
    images = os.listdir(path)
    for i, v in enumerate(components):
        print("cluster no: %d" % i)
        n = min(len(v), max_subplots)
        plt.figure()
        for j in range(n):
            print(os.path.join(path, images[v[j]]))
            im = cv2.imread(os.path.join(path, images[v[j]]))
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
            print("subplotting [%d/%d]" % (j + 1, n))
            plt.subplot(1, n, j + 1)
            plt.xticks([])
            plt.yticks([])
            plt.imshow(im)
        plt.savefig(os.path.join(output, 'cluster_%d.png' % i))
        plt.close()  # start a fresh figure for the next cluster
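
# `visualize` is not invoked in the main block below. A hypothetical call,
# assuming a directory of cropped word images and a list of cluster index
# lists named `components` (both assumptions, not part of this script):
#
#     visualize("images/book_0/", components, args.output)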
if __name__ == '__main__':
    parser = ArgumentParser()
    base_opts(parser)
    args = parser.parse_args()
    with open(args.config) as config_file:
        config = json.load(config_file)
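    # A minimal sketch of the expected config file, inferred from the keys
    # read below (all paths and names are placeholders, not real values):
    #
    # {
    #     "model": "models/graves_ocr.model",
    #     "lookup": "models/lookup.txt",
    #     "feat_dir": "features/",
    #     "books": ["book_0", "book_1"]
    # }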
    # Load the OCR model (instantiated here, but not used in the KNN evaluation below).
    print(config["model"])
    ocr = GravesOCR(config["model"], config["lookup"])

    # Load the precomputed image features and the annotation file for the requested book.
    book_name = config["books"][args.book]
    neigh = KNeighborsClassifier(n_neighbors=3)
    fpath = os.path.join(config["feat_dir"], book_name)
    feat = np.load(os.path.join(fpath, "feats.npy"))
    with open(os.path.join(fpath, 'annotation.txt'), 'r') as in_file:
        lines = in_file.readlines()

    # Get the ground-truth words and image features from the annotation and feature files.
    words = [line.rsplit()[1] for line in lines]
    features = [feat[i] for i in range(feat.shape[0])]
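    # Each annotation line is assumed to be whitespace-separated with the
    # ground-truth word in the second column (inferred from `rsplit()[1]`);
    # the first column is presumably an image identifier, e.g.:
    #
    #     0001.png word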
    # Train-test split: the first 2000 samples train the classifier, the rest are held out.
    test_words = np.array(words[2000:])
    test_features = np.array(features[2000:])

    # get_index returns the indices of each word, ordered by descending frequency;
    # keep only those test words that have at least 3 instances in the test set.
    test_index = get_index(test_words)
    indices = sum([v for i, v in test_index.items() if len(v) >= 3], [])
    test_words = test_words[indices]
    test_features = test_features[indices]
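    # sum(..., []) concatenates the per-word index lists into one flat list,
    # e.g. sum([[0, 2], [1]], []) == [0, 2, 1].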
    # Fit the KNN classifier on the training split and score it on the filtered test set.
    neigh.fit(features[:2000], words[:2000])
    y_predict = neigh.predict(test_features)
    acc = [1 if y_predict[i] == test_words[i] else 0 for i in range(len(test_words))]
    print(sum(acc) / len(test_words))
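    # An equivalent, more idiomatic way to compute the same accuracy with
    # NumPy (a sketch; both arrays align element-wise, so the result is identical):
    #
    #     accuracy = (y_predict == test_words).mean()
    #     print(accuracy)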