Skip to content

Instantly share code, notes, and snippets.

@cryzed
Created June 20, 2018 17:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cryzed/f25823ea594a2cdd8a41eb81e370e662 to your computer and use it in GitHub Desktop.
Save cryzed/f25823ea594a2cdd8a41eb81e370e662 to your computer and use it in GitHub Desktop.
import argparse
import collections
import itertools
import os
import sys
import time
try:
import cPickle as pickle
except ImportError:
import pickle
import PIL.Image
import numpy as np
import vlfeat
import scipy.cluster.vq
import scipy.spatial.distance
import matplotlib.pyplot as plt
# Input locations, relative to the current working directory.
PAGES_PATH = os.path.join('data', 'pages')  # page images, loaded as <page name>.png
GT_PATH = os.path.join('data', 'GT')  # ground truth annotations, loaded as <page name>.gtp
# Output directories for cropped match images.
IFS_MATCH_IMAGES_PATH = os.path.join('ifs_match_images')
MATCH_IMAGES_PATH = os.path.join('match_images')
# Raw float32 dump of the pre-computed codebook that is used when --centroids is 4096.
CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')
# Accepted values for --spatial-pyramid-type: any listed combination of
# G (global), L (left) and R (right) histograms.
SPATIAL_PYRAMID_TYPES = ['L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR']
# Accepted values for --cell-margin: the axes along which a word bounding box
# is shrunk before descriptors are assigned to it.
CELL_MARGIN_TYPES = ['none', 'horizontal', 'vertical', 'both']

# Command line interface. Flags, types, defaults and choices are part of the
# script's interface; only descriptive help text is attached to them here.
argument_parser = argparse.ArgumentParser(
    description='Query-by-example word spotting with dense SIFT, bag-of-visual-words spatial pyramids '
                'and an optional inverted file structure (IFS).')
argument_parser.add_argument(
    '--step-size', '-s', type=int, default=15,
    help='dense SIFT sampling step passed to vl_dsift (default: %(default)s)')
argument_parser.add_argument(
    '--cell-size', '-c', type=int, default=3,
    help='dense SIFT cell size passed to vl_dsift; twice this value is used as the cell margin '
         '(default: %(default)s)')
argument_parser.add_argument(
    '--centroids', '-C', type=int, default=40,
    help='number of k-means centroids; 4096 quantizes against the codebook file instead of clustering '
         '(default: %(default)s)')
argument_parser.add_argument(
    '--k-means-iterations', '-k', type=int, default=20,
    help='number of k-means iterations (default: %(default)s)')
argument_parser.add_argument(
    '--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine',
    help='metric used to rank candidates against the query (default: %(default)s)')
argument_parser.add_argument(
    '--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR',
    help='histograms to include: G (global), L (left), R (right) (default: %(default)s)')
argument_parser.add_argument(
    '--pages', '-p', type=int, default=1,
    help='number of pages to load, taken in sorted filename order (default: %(default)s)')
# NOTE: pre_main() sweeps the accumulator percentile and forces the IFS and
# accumulator on, so the next three options are overridden when the script is
# started through its __main__ entry point.
argument_parser.add_argument(
    '--accumulator-percentile', '-a', type=float, default=95.0,
    help='percentile of the distinct accumulator vote counts a candidate has to reach '
         '(default: %(default)s)')
argument_parser.add_argument(
    '--use-ifs', '-I', action='store_true',
    help='pre-select candidates with the inverted file structure')
argument_parser.add_argument(
    '--use-accumulator', '-A', action='store_true',
    help='filter IFS candidates by their vote count (only applies together with --use-ifs)')
argument_parser.add_argument(
    '--save-images', '-sa', action='store_true',
    help='save crops of every query and its ranked candidates')
argument_parser.add_argument(
    '--verbose', action='store_true',
    help='report skipped queries on stderr')
argument_parser.add_argument(
    '--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal',
    help='axes along which word bounding boxes are shrunk by the cell margin (default: %(default)s)')

# Names the three histogram slots of a spatial pyramid.
SpatialPyramid = collections.namedtuple('SpatialPyramid', ['global_', 'left', 'right'])
def makedirs(name, mode=0o777, exist_ok=False):
    """Create directory ``name`` and any missing parents.

    Backport of the ``exist_ok`` behaviour of Python 3's ``os.makedirs``.

    Args:
        name: Path of the directory to create.
        mode: Permission bits passed to ``os.makedirs``. Written as ``0o777``
            because that octal spelling is valid on both Python 2.6+ and
            Python 3, unlike ``0777``.
        exist_ok: When true, an already existing directory is not an error.

    Raises:
        OSError: If the directory cannot be created, or if it already exists
            and ``exist_ok`` is false, or if ``name`` exists but is not a
            directory.
    """
    try:
        os.makedirs(name, mode)
    except OSError:
        # Cannot rely on checking for EEXIST, since the operating system
        # could give priority to other errors like EACCES or EROFS
        if not exist_ok or not os.path.isdir(name):
            raise
def load_gtp_file(path):
    """Read a ground truth file into a word -> bounding boxes mapping.

    Every non-blank line must consist of exactly five whitespace separated
    fields: four integers followed by the word they annotate.

    Args:
        path: Path of the file to read.

    Returns:
        A ``collections.defaultdict(list)`` mapping each word to the list of
        its ``(x1, y1, x2, y2)`` integer tuples, in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly five fields or
            one of the first four is not an integer.
    """
    boxes_by_word = collections.defaultdict(list)
    with open(path) as handle:
        for raw_line in handle:
            record = raw_line.strip()
            # Blank (or whitespace-only) lines carry no annotation.
            if not record:
                continue
            x_a, y_a, x_b, y_b, word = record.split()
            boxes_by_word[word].append((int(x_a), int(y_a), int(x_b), int(y_b)))
    return boxes_by_word
def load_codebook(path, shape=(4096, 128)):
    """Load a codebook stored as a raw dump of float32 values.

    Args:
        path: Path of the binary codebook file.
        shape: Shape the flat value sequence is reshaped to. The default
            matches the previously hard-coded 4096 x 128 layout.

    Returns:
        A float32 ``numpy.ndarray`` of the requested shape.

    Raises:
        ValueError: If the number of values in the file does not fit ``shape``.
    """
    # Binary mode avoids any text-layer translation of the raw bytes, and the
    # context manager closes the handle that used to be leaked.
    with open(path, 'rb') as input_file:
        code_book = np.fromfile(input_file, dtype='float32')
    return np.reshape(code_book, shape)
def make_spatial_pyramid(data, length, type_='GLR'):
    """Build a concatenated global/left/right label histogram.

    ``data`` is split at ``len(data) // 2``: the left part is everything
    before that index, the right part everything from it onwards. This is the
    split the original ``floor(count / 2)`` / ``ceil(count / 2)`` produced
    under Python 2 integer division; spelling it with ``//`` keeps the result
    the same on Python 3, where true division would otherwise drop the middle
    element of odd-length input from both halves.

    Args:
        data: Sequence of non-negative integer labels supporting ``len`` and
            slicing.
        length: ``minlength`` of each of the three histograms.
        type_: Which histograms to fill: one of 'L', 'R', 'G', 'GL', 'GR',
            'LR' or 'GLR'. Histograms that are not selected stay all zero, so
            the layout of the result does not depend on ``type_``.

    Returns:
        A 1-D integer array holding the global, left and right histograms
        back to back (in that order).

    Raises:
        ValueError: If ``type_`` is not one of the supported values.
    """
    if type_ not in ('L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR'):
        raise ValueError('unknown spatial pyramid type: %r' % type_)

    midpoint = len(data) // 2
    # An empty list yields an all-zero histogram of ``length`` bins.
    parts = (
        data if 'G' in type_ else [],
        data[:midpoint] if 'L' in type_ else [],
        data[midpoint:] if 'R' in type_ else [],
    )
    return np.concatenate([np.bincount(part, minlength=length) for part in parts])
def load_corpus(page_names):
    """Load the given pages and stitch them into one wide corpus image.

    The pages are placed side by side, left to right in the order given, and
    every ground truth box is shifted by the horizontal offset of its page so
    that it addresses the stitched image.

    Args:
        page_names: Page names without extension; ``<name>.png`` is read from
            ``PAGES_PATH`` and ``<name>.gtp`` from ``GT_PATH``.

    Returns:
        An arbitrarily nestable ``defaultdict`` with these entries:
        ``corpus['pages'][name]`` holds 'offset', 'image_path', 'image',
        'gtp_path' and 'gtp' of a single page; ``corpus['gtp']`` maps each
        word to its offset-corrected boxes across all pages;
        ``corpus['image']`` is the stitched PIL image and ``corpus['data']``
        the same image as a float32 array.
    """
    def nested_dict():
        # Missing keys transparently create another nested level.
        return collections.defaultdict(nested_dict)

    corpus = collections.defaultdict(nested_dict)
    merged_gtp = collections.defaultdict(list)
    page_images = []
    x_position = 0
    for page_name in page_names:
        page = corpus['pages'][page_name]
        page['offset'] = x_position

        # Page image
        page_image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
        page['image_path'] = page_image_path
        page_image = PIL.Image.open(page_image_path)
        page['image'] = page_image
        page_images.append(page_image)

        # Page ground truth
        page_gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
        page['gtp_path'] = page_gtp_path
        page_gtp = load_gtp_file(page_gtp_path)
        page['gtp'] = page_gtp

        # Shift the boxes horizontally into corpus image coordinates
        for word, boxes in page_gtp.items():
            merged_gtp[word].extend(
                (x1 + x_position, y1, x2 + x_position, y2) for x1, y1, x2, y2 in boxes)

        x_position += page_image.width

    # The canvas is as wide as all pages together and as tall as the tallest
    # page; it takes the mode of the first page.
    total_width = sum(page_image.width for page_image in page_images)
    tallest = max(page_image.height for page_image in page_images)
    canvas = PIL.Image.new(page_images[0].mode, (total_width, tallest))
    paste_x = 0
    for page_image in page_images:
        canvas.paste(page_image, (paste_x, 0))
        paste_x += page_image.width

    corpus['gtp'] = merged_gtp
    corpus['image'] = canvas
    corpus['data'] = np.array(canvas, dtype='float32')
    return corpus
def _plot_sweep(results, ylabel, ylim=None):
    """Show one swept quantity as points over the accumulator percentiles."""
    positions = list(range(len(results)))
    plt.plot(positions, list(results.values()), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel(ylabel)
    plt.xticks(positions, list(results.keys()))
    plt.grid(True)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.tight_layout()
    plt.show()


def pre_main(arguments):
    """Sweep the accumulator percentile from 0 to 100 and plot the outcome.

    The corpus is loaded once; ``main`` is then run for every percentile in
    steps of 5 with the IFS and the accumulator switched on. Afterwards the
    mean average precision and the runtime (whole seconds) per percentile are
    shown in two consecutive plots.

    Note that this overwrites ``use_ifs``, ``use_accumulator`` and
    ``accumulator_percentile`` on ``arguments``, so the corresponding command
    line values have no effect on the sweep.

    Args:
        arguments: Parsed command line namespace; ``pages`` selects how many
            pages (in sorted filename order) are loaded, everything else is
            forwarded to ``main``.
    """
    # The codebook is not loaded here: main() loads it itself when it is
    # needed (4096 centroids). The former call only discarded its result and
    # made the codebook file mandatory for every configuration.
    page_names = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    # Constant across the whole sweep, so set once.
    arguments.use_ifs = True
    arguments.use_accumulator = True

    precision_by_percentile = collections.OrderedDict()
    runtime_by_percentile = collections.OrderedDict()
    for accumulator_percentile in range(0, 105, 5):
        print(accumulator_percentile)
        arguments.accumulator_percentile = accumulator_percentile
        start = time.time()
        mean_average_precision = main(arguments, corpus)
        duration = int(time.time() - start)
        precision_by_percentile[accumulator_percentile] = mean_average_precision
        runtime_by_percentile[accumulator_percentile] = duration

    _plot_sweep(precision_by_percentile, 'Mean Average Precision', ylim=(0, 1))
    _plot_sweep(runtime_by_percentile, 'Runtime')
def main(arguments, corpus):
    """Run one word spotting evaluation over the corpus.

    Every annotated word occurrence is described by a spatial pyramid of
    visual word histograms and used once as the query against all other
    occurrences. Queries whose word occurs only once are skipped. Mean recall
    and mean average precision are printed as percentages.

    Args:
        arguments: Parsed command line namespace. Reads ``step_size``,
            ``cell_size``, ``cell_margin``, ``centroids``,
            ``k_means_iterations``, ``spatial_pyramid_type``,
            ``distance_metric``, ``use_ifs``, ``use_accumulator``,
            ``accumulator_percentile``, ``save_images`` and ``verbose``.
        corpus: Mapping as returned by ``load_corpus``; uses ``'data'`` (float
            image array), ``'image'`` (PIL image, only for saving crops) and
            ``'gtp'`` (word -> list of ``(x1, y1, x2, y2)`` boxes).

    Returns:
        The mean average precision over all evaluated queries, as a fraction.
        NOTE: ``numpy.mean`` of an empty list is ``nan``, which is what is
        returned if no word occurs more than once.

    Raises:
        RuntimeError: If ``arguments.cell_margin`` is not a known margin type.
    """
    # Resolve the margin configuration once instead of once per bounding box.
    if arguments.cell_margin not in ('none', 'horizontal', 'vertical', 'both'):
        raise RuntimeError('unknown cell margin type: %r' % (arguments.cell_margin,))
    cell_margin = 2 * arguments.cell_size
    margin_x = cell_margin if arguments.cell_margin in ('horizontal', 'both') else 0
    margin_y = cell_margin if arguments.cell_margin in ('vertical', 'both') else 0

    # Calculate SIFT data for corpus. The float array is normalised by its
    # maximum; the PIL image itself is not an array and is not divided.
    frames, descriptors = vlfeat.vl_dsift(
        corpus['data'] / corpus['data'].max(), step=arguments.step_size, size=arguments.cell_size,
        fast=True, float_descriptors=True)

    # Find all descriptors whose frame lies inside a word bounding box, shrunk
    # by the cell margin along the configured axes. Column 0 of the frames is
    # compared with the x bounds and column 1 with the y bounds.
    frame_x = frames[:, 0]
    frame_y = frames[:, 1]
    words_descriptors = []
    previous_frame_index = 0
    word_data_indices = collections.OrderedDict()
    word_coordinates = collections.OrderedDict()
    for word, coordinates in corpus['gtp'].items():
        for variation, (x1, y1, x2, y2) in enumerate(coordinates):
            mask = (
                (frame_x >= x1 + margin_x) & (frame_y >= y1 + margin_y) &
                (frame_x <= x2 - margin_x) & (frame_y <= y2 - margin_y))
            word_descriptors = descriptors[mask]
            words_descriptors.append(word_descriptors)

            # Note at which index and how many (following) descriptors are part of a word
            frame_count = word_descriptors.shape[0]
            key = word, variation
            word_data_indices[key] = previous_frame_index, frame_count
            word_coordinates[key] = x1, y1, x2, y2
            previous_frame_index += frame_count

    words_descriptors = np.concatenate(words_descriptors)

    # Quantize every descriptor to a visual word label
    if arguments.centroids == 4096:
        code_book = load_codebook(CODE_BOOK_PATH)
        labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
        _, labels = scipy.cluster.vq.kmeans2(
            words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # Create (word, variation) -> spatial pyramid mapping
    # noinspection PyArgumentList
    spatial_pyramids = collections.OrderedDict(
        (key, make_spatial_pyramid(
            labels[start:start + length], arguments.centroids, arguments.spatial_pyramid_type))
        for key, (start, length) in word_data_indices.items())

    # Materialised as lists so that they can be indexed by word index on both
    # Python 2 and Python 3.
    spatial_pyramids_values = list(spatial_pyramids.values())
    word_coordinates_values = list(word_coordinates.values())

    # Create IFS database: one set of word indices per pyramid bin, holding
    # the words with a non-zero count in that bin
    ifs_height = len(spatial_pyramids_values[0])
    ifs = [set() for _ in range(ifs_height)]
    for word_index, spatial_pyramid in enumerate(spatial_pyramids_values):
        for index, count in enumerate(spatial_pyramid):
            if count:
                ifs[index].add(word_index)

    # Create word -> set of word indices (one per variation) mapping
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, variation) in enumerate(spatial_pyramids.keys()):
        word_variation_indices[word].add(word_index)

    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(spatial_pyramids.items()):
        # Skip words with no findable duplicates in the IFS database
        relevant_indices = word_variation_indices[word]
        appearances = len(relevant_indices) - 1
        if not appearances:
            if arguments.verbose:
                sys.stderr.write('No duplicate appearances for (%s, %d)!\n' % (word, variation))
            continue

        if arguments.use_ifs:
            # Every word sharing at least one non-zero bin with the query; a
            # word is listed once per shared bin, which the accumulator counts.
            ifs_candidate_indices = list(itertools.chain.from_iterable(
                ifs[index] for index, count in enumerate(query) if count))
            candidate_indices = set(ifs_candidate_indices)
            if not candidate_indices:
                if arguments.verbose:
                    sys.stderr.write('No candidates for (%s, %d) after IFS!\n' % (word, variation))
                average_precisions.append(0)
                average_recalls.append(0)
                continue

            if arguments.use_accumulator:
                # The candidate list is non-empty here, so the accumulator is too.
                # noinspection PyArgumentList
                accumulator = collections.Counter(ifs_candidate_indices)
                rankings = sorted(set(accumulator.values()))
                # Clamped so that a percentile above 100 selects the top
                # ranking instead of indexing past the end.
                percentile_index = min(
                    len(rankings) - 1,
                    max(0, int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1))
                percentile_ranking = rankings[percentile_index]
                candidate_indices = {
                    index for index, count in accumulator.items() if count >= percentile_ranking}
        else:
            candidate_indices = set(range(len(spatial_pyramids)))

        # The query must not be retrieved as its own match
        candidate_indices.discard(word_index)
        if not candidate_indices:
            if arguments.verbose:
                sys.stderr.write('No candidates for (%s, %d)\n' % (word, variation))
            average_precisions.append(0)
            average_recalls.append(0)
            continue

        # Fix one iteration order so that distance positions map back to candidates
        candidate_list = list(candidate_indices)
        candidate_pyramids = np.array([spatial_pyramids_values[index] for index in candidate_list])
        distances = scipy.spatial.distance.cdist(
            query.reshape((1, query.shape[0])), candidate_pyramids, metric=arguments.distance_metric)[0]
        sorted_candidate_indices = [candidate_list[position] for position in np.argsort(distances)]

        # Average precision: precision at every rank that holds a relevant
        # candidate, summed and divided by the number of other appearances
        true_positives = 0
        precision_sum = 0.0
        for rank, candidate_index in enumerate(sorted_candidate_indices, start=1):
            if candidate_index in relevant_indices:
                true_positives += 1
                precision_sum += true_positives / float(rank)
        average_precisions.append(precision_sum / float(appearances))
        average_recalls.append(true_positives / float(appearances))

        if arguments.save_images:
            match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
            makedirs(match_images_path, exist_ok=True)
            coordinates = word_coordinates_values[word_index]
            corpus['image'].crop(coordinates).save(os.path.join(match_images_path, '0_original.png'))
            for rank, candidate_word_index in enumerate(sorted_candidate_indices, start=1):
                coordinates = word_coordinates_values[candidate_word_index]
                path = os.path.join(match_images_path, 'candidate_%d.png' % rank)
                corpus['image'].crop(coordinates).save(path)

    print('Mean Recall: %f' % (np.mean(average_recalls) * 100))
    mean_average_precision = np.mean(average_precisions)
    print('Mean Average Precision: %f' % (mean_average_precision * 100))
    return mean_average_precision
if __name__ == '__main__':
    # Script entry point: parse the command line and start the percentile sweep.
    pre_main(argument_parser.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment