@kevinfjiang
Created November 26, 2023 06:35
Implementation in Python of the k-NN LAESA algorithm of F. Moreno-Seco et al.
from typing import Callable, Sequence, Iterator
import random
import heapq
class HeapKey[T]:
    """Heap wrapper because I dislike the Python heap API."""

    def __init__(self, data: Sequence[T], key: Callable[[T], float], min_heap: bool = True):
        """Builds the heap in O(n) with heapify; all other operations are O(log n)."""
        self.key = key if min_heap else lambda entry: -key(entry)
        self.min_heap = min_heap
        self.heap = [(self.key(entry), entry) for entry in data]
        heapq.heapify(self.heap)

    def push(self, entry: T):
        heapq.heappush(self.heap, (self.key(entry), entry))

    def peak(self) -> T:
        if not self:
            raise ValueError("Heap must not be empty to peak.")
        return self.heap[0][1]

    def peak_val(self) -> float:
        if not self:
            raise ValueError("Heap must not be empty to peak.")
        return self.heap[0][0] if self.min_heap else -self.heap[0][0]

    def pop(self) -> T:
        if not self:
            raise ValueError("Heap must not be empty to pop.")
        return heapq.heappop(self.heap)[1]

    def __len__(self):
        return len(self.heap)

    def truncate(self, size: int):
        """Pops elements from the head until the heap reaches the desired size."""
        while self and len(self) > size:
            self.pop()

    def __iter__(self) -> Iterator[T]:
        """This iterator consumes the heap entirely."""
        while self:
            yield self.pop()
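
# A minimal usage sketch of HeapKey (not part of the original gist): keep the three
# smallest of a handful of floats by truncating a max-heap down to size 3.
#
#     top3 = HeapKey([5.0, 1.0, 9.0, 2.0, 7.0], key=lambda x: x, min_heap=False)
#     top3.truncate(3)    # pops the largest entries until only 3 remain
#     list(top3)          # consumes the heap: [5.0, 2.0, 1.0], largest remaining first
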
class KnnLaesa[T]:
    """Computes approximate k nearest neighbors with linear preprocessing.

    References
    ----------
    .. [1] M. L. Mico, J. Oncina, and E. Vidal,
       “A new version of the nearest-neighbour approximating and eliminating search algorithm (AESA) with linear preprocessing time and memory requirements,”
       Pattern Recognition Letters, vol. 15, no. 1, pp. 9-17, Jan. 1994, doi: 10.1016/0167-8655(94)90095-7.
    .. [2] F. Moreno-Seco, L. Mico, and J. Oncina,
       “A modification of the LAESA algorithm for approximated k-NN classification,”
       Pattern Recognition Letters, vol. 24, no. 1, pp. 47-53, Jan. 2003, doi: 10.1016/S0167-8655(02)00187-3.

    Parameters
    ----------
    candidates : list[T]
        List of candidates for a neighbor. Referred to as the set of all points $M$.
    distance : Callable[[T, T], float]
        Distance metric between candidates T, referred to as $d$ for a metric space.
    num_bases : int, optional
        To limit the number of inter-candidate distances calculated, preprocessing only
        computes distances from each candidate to the base candidates. The `num_bases`
        bases are chosen to maximize the distances between them, by default 25.
    k_neighbors : int, optional
        Number of nearest neighbors to find, by default 5.
    """

    def __init__(self, candidates: list[T], distance: Callable[[T, T], float], num_bases: int = 25, k_neighbors: int = 5):
        self.dist = distance
        self.candidates = candidates
        self.num_candidates = len(candidates)
        self.num_bases = num_bases
        self.k_neighbors = k_neighbors
        assert 0 < self.num_bases, "Number of bases must be a positive integer"
        assert 0 < self.k_neighbors, "Number of neighbors must be a positive integer"

        # Used for the LAESA algorithm: we compute the distance from every point to the
        # base candidates so we can narrow our search with the triangle inequality.
        self.base_indices = [random.choice(range(self.num_candidates))]  # arbitrary
        self.base_dist = [[0 for _ in range(self.num_candidates)] for _ in range(num_bases)]
        lower_bounds = [0 for _ in range(self.num_candidates)]
        for i in range(num_bases):
            current_base = candidates[self.base_indices[i]]
            max_dist_index = 0
            for j in range(self.num_candidates):
                if j in self.base_indices:
                    continue
                self.base_dist[i][j] = self.dist(current_base, candidates[j])
                lower_bounds[j] += self.base_dist[i][j]
                # We want the next base to be as far from the chosen bases as possible
                if lower_bounds[j] > lower_bounds[max_dist_index]:
                    max_dist_index = j
            self.base_indices.append(max_dist_index)
        self.base_indices.pop()  # Removes the last base, since its distances are never computed
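
    # Sketch of the pruning bound used in predict, assuming only that `dist` is a
    # metric: for a query q, candidate c, and base b, the triangle inequality gives
    #     dist(q, c) >= |dist(q, b) - dist(b, c)|,
    # so taking the max over all bases b yields the tightest available lower bound
    # on dist(q, c) without ever computing that distance directly.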

    def predict(self, target: T) -> list[T]:
        target_dist = [self.dist(target, self.candidates[p]) for p in self.base_indices]

        def compute_lb(j: int) -> float:
            """Computes the highest lower bound using the triangle inequality and the bases."""
            return max(abs(target_dist[i] - self.base_dist[i][j]) for i in range(self.num_bases))

        lower_bounds = [compute_lb(j) for j in range(self.num_candidates)]
        # We assume the total ordering of our lower bounds is approximately correct.
        # The heap ensures that all further lower bounds are greater than the best distances.
        non_bases = (i for i in range(self.num_candidates) if i not in self.base_indices)
        lb_heap = HeapKey(non_bases, key=lambda i: lower_bounds[i])
        knn_neighbors = HeapKey(self.base_indices, key=lambda i: self.dist(self.candidates[i], target), min_heap=False)
        while lb_heap and (lb_heap.peak_val() <= knn_neighbors.peak_val() or len(knn_neighbors) < self.k_neighbors):
            knn_neighbors.push(lb_heap.pop())
            knn_neighbors.truncate(self.k_neighbors)
        knn_neighbors.truncate(self.k_neighbors)
        return [self.candidates[i] for i in reversed(list(knn_neighbors))]
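
if __name__ == "__main__":
    # A minimal smoke test, not part of the original gist: 2-D points under the
    # Euclidean metric (any metric-space distance works the same way).
    import math

    random.seed(0)
    points = [(random.random(), random.random()) for _ in range(500)]

    index = KnnLaesa(points, distance=math.dist, num_bases=10, k_neighbors=5)
    neighbors = index.predict((0.5, 0.5))
    print(neighbors)  # approximate 5 nearest points to (0.5, 0.5), nearest first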