Implementation of k-NN LAESA in Python by F. Moreno-Seco et al.
from typing import Callable, Iterable, Iterator
import itertools
import random
import heapq
class HeapKey[T]:
    """Heap wrapper because I dislike the Python heap API."""

    def __init__(self, data: Iterable[T], key: Callable[[T], float], min_heap: bool = True):
        """Builds the heap in O(n) time with heapify; all other operations are O(log n)."""
        self.key = key if min_heap else lambda entry: -key(entry)
        self.min_heap = min_heap
        # The running counter breaks ties between equal keys, so entries of type T
        # never need to be comparable themselves.
        self._counter = itertools.count()
        self.heap = [(self.key(entry), next(self._counter), entry) for entry in data]
        heapq.heapify(self.heap)

    def push(self, entry: T):
        heapq.heappush(self.heap, (self.key(entry), next(self._counter), entry))

    def peek(self) -> T:
        if not self:
            raise ValueError("Heap must not be empty to peek.")
        return self.heap[0][2]

    def peek_val(self) -> float:
        if not self:
            raise ValueError("Heap must not be empty to peek.")
        return self.heap[0][0] if self.min_heap else -self.heap[0][0]

    def pop(self) -> T:
        if not self:
            raise ValueError("Heap must not be empty to pop.")
        return heapq.heappop(self.heap)[2]

    def __len__(self):
        return len(self.heap)

    def truncate(self, size: int):
        """Pops elements from the head until the heap reaches the desired size."""
        while self and len(self) > size:
            self.pop()

    def __iter__(self) -> Iterator[T]:
        """This iterator consumes the heap entirely."""
        while self:
            yield self.pop()
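

# A minimal usage sketch for HeapKey (illustrative, not part of the original gist):
# a min-heap keyed by string length, consumed shortest-first.
#
#   words = HeapKey(["banana", "fig", "apple"], key=len)
#   words.peek()       # "fig" -- the entry with the smallest key
#   words.truncate(2)  # pops "fig", keeping the two entries with the largest keys
#   list(words)        # consumes the heap: ["apple", "banana"]
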
class KnnLaesa[T]:
    """Computes approximate k nearest neighbors with linear preprocessing.

    References
    ----------
    .. [1] M. L. Mico, J. Oncina, and E. Vidal,
       "A new version of the nearest-neighbour approximating and eliminating search
       algorithm (AESA) with linear preprocessing time and memory requirements,"
       Pattern Recognition Letters, vol. 15, no. 1, pp. 9-17, Jan. 1994,
       doi: 10.1016/0167-8655(94)90095-7.
    .. [2] F. Moreno-Seco, L. Mico, and J. Oncina,
       "A modification of the LAESA algorithm for approximated k-NN classification,"
       Pattern Recognition Letters, vol. 24, no. 1, pp. 47-53, Jan. 2003,
       doi: 10.1016/S0167-8655(02)00187-3.

    Parameters
    ----------
    candidates : list[T]
        List of candidate neighbors, referred to as the set of all points $M$.
    distance : Callable[[T, T], float]
        Distance metric between candidates T, referred to as $d$ for a metric space.
    num_bases : int, optional
        To limit the number of inter-candidate distances calculated, we only compute
        inter-candidate distances for the bases during preprocessing. The `num_bases`
        bases are selected greedily to maximize the distances between them, by default 25.
    k_neighbors : int, optional
        Number of nearest neighbors to find, by default 5.
    """
    def __init__(self, candidates: list[T], distance: Callable[[T, T], float], num_bases: int = 25, k_neighbors: int = 5):
        self.dist = distance
        self.candidates = candidates
        self.num_candidates = len(candidates)
        self.num_bases = num_bases
        self.k_neighbors = k_neighbors
        assert 0 < self.num_bases, "Number of bases must be a positive integer"
        assert 0 < self.k_neighbors, "Number of neighbors must be a positive integer"

        # Used for the LAESA algorithm: we compute the distance from every candidate to the
        # base candidates so we can narrow our search with the triangle inequality.
        self.base_indices = [random.choice(range(self.num_candidates))]  # arbitrary first base
        self.base_dist = [[0.0 for _ in range(self.num_candidates)] for _ in range(num_bases)]
        lower_bounds = [0.0 for _ in range(self.num_candidates)]
        for i in range(num_bases):
            current_base = candidates[self.base_indices[i]]
            max_dist_index, max_lower_bound = 0, float("-inf")
            for j in range(self.num_candidates):
                if j in self.base_indices:
                    continue
                self.base_dist[i][j] = self.dist(current_base, candidates[j])
                lower_bounds[j] += self.base_dist[i][j]
                # We want the next base to be as far from the existing bases as possible
                if lower_bounds[j] > max_lower_bound:
                    max_dist_index, max_lower_bound = j, lower_bounds[j]
            self.base_indices.append(max_dist_index)
        self.base_indices.pop()  # Removes the last base, as its distances are never computed
    def predict(self, target: T) -> list[T]:
        target_dist = [self.dist(target, self.candidates[p]) for p in self.base_indices]

        def compute_lb(j: int) -> float:
            """Computes the tightest lower bound from the triangle inequality over all bases."""
            return max(abs(target_dist[i] - self.base_dist[i][j]) for i in range(self.num_bases))

        lower_bounds = [compute_lb(j) for j in range(self.num_candidates)]
        # We assume the total ordering of our lower bounds is approximately correct.
        # The heap ensures that all further lower bounds are greater than the current
        # k-th best distance, at which point the search can stop.
        non_bases = (i for i in range(self.num_candidates) if i not in self.base_indices)
        lb_heap = HeapKey(non_bases, key=lambda i: lower_bounds[i])
        knn_neighbors = HeapKey(self.base_indices, key=lambda i: self.dist(self.candidates[i], target), min_heap=False)
        while lb_heap and (lb_heap.peek_val() <= knn_neighbors.peek_val() or len(knn_neighbors) < self.k_neighbors):
            knn_neighbors.push(lb_heap.pop())
            knn_neighbors.truncate(self.k_neighbors)
        knn_neighbors.truncate(self.k_neighbors)
        # The max-heap yields the farthest of the k neighbors first, so reverse to return nearest-first.
        return [self.candidates[i] for i in reversed(list(knn_neighbors))]
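

# A minimal usage sketch (not part of the original gist): the random 2-D points, the
# `euclidean` helper, and the parameter values below are illustrative assumptions; the
# brute-force comparison is only there as a sanity check.
if __name__ == "__main__":
    import math

    def euclidean(a: tuple[float, float], b: tuple[float, float]) -> float:
        return math.hypot(a[0] - b[0], a[1] - b[1])

    random.seed(0)
    points = [(random.uniform(0, 100), random.uniform(0, 100)) for _ in range(1_000)]

    index = KnnLaesa(points, distance=euclidean, num_bases=25, k_neighbors=5)
    query = (50.0, 50.0)
    approximate = index.predict(query)  # nearest-first, approximate

    exact = sorted(points, key=lambda p: euclidean(p, query))[:5]
    print("approximate:", approximate)
    print("exact      :", exact)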