
@blogle
Created December 4, 2014 23:02
bvka_nn.py
import numpy as np
import networkx as nx
import heapq
from collections import defaultdict


def distance(u, v):
    """Squared Euclidean distance between points u and v."""
    return np.sum((u - v)**2)

class KDTree(object):

    def __init__(self, data, index=None, depth=0):
        """
        Creates a recursive space-partitioning data structure where each node
        splits its axis at the median of that axis. Similar to a BST, it
        provides O(n log n) creation and O(log n) queries.

        Args:
            data (np.array): dataset with shape (n, k): n observations, k dimensions.

        Optional:
            index (np.array): index corresponding to each row of data;
                if left empty, the data is zero-indexed.
            depth (int): determines the axis on which to first
                partition, e.g. 0 -> x, 1 -> y, 2 -> z.

        Notes:
            http://en.wikipedia.org/wiki/K-d_tree
            http://en.wikipedia.org/wiki/K-d_tree#mediaviewer/File:KDTree-animation.gif
        """
        # Build the index at the top level
        if index is None:
            index = np.arange(data.shape[0])

        self.n = None
        self.k = None
        self.idx = None
        self.node = None
        self.axis = None
        self.left = None
        self.right = None
        self.children = None

        self._build(data, index, depth)
    def _build(self, data, index, depth):
        """Recursively builds the child nodes of the KDTree."""
        # If there is data to partition, create nodes
        if data[index].size:
            # Store the dimensions of the data and the axis to partition on
            self.n, self.k = data[index].shape
            self.axis = (self.k + depth) % self.k

            # Indices of all nodes beneath this node
            self.children = index

            # Sort the data on the current axis and find the median
            # element on which to partition
            idx_data = np.column_stack((data[index], index))
            sort_ax = idx_data[np.argsort(idx_data[:, self.axis]), -1].astype(int)
            partition = sort_ax.size // 2

            # Node index and data
            self.idx = sort_ax[partition]
            self.node = data[self.idx]

            # Build the branches, partitioning on the next axis
            self.left = KDTree(data, sort_ax[:partition], depth + 1)
            self.right = KDTree(data, sort_ax[partition + 1:], depth + 1)
    def near_branch(self, point):
        """Returns the branch nearer to the input point."""
        if point[self.axis] < self.node[self.axis]:
            return self.left
        return self.right

    def far_branch(self, point):
        """Returns the branch farther from the input point."""
        if self.near_branch(point) == self.left:
            return self.right
        return self.left

    def orthogonal_dist(self, point):
        """Computes the distance from a point to the splitting hyperplane."""
        orth_point = np.copy(point)
        orth_point[self.axis] = self.node[self.axis]
        return distance(point, orth_point)
    def query(self, point, best=None):
        """Find the nearest neighbor of point in the KDTree."""
        # Dead end: backtrack up the tree
        if self.node is None:
            return best

        # Initialize best
        if best is None:
            best = (self.idx, self.node)

        # Check if the current node is closer than best
        if distance(self.node, point) < distance(best[1], point):
            best = (self.idx, self.node)

        # Continue traversing the tree
        best = self.near_branch(point).query(point, best)

        # Traverse the away branch if the orthogonal distance is less than best
        if self.orthogonal_dist(point) < distance(best[1], point):
            best = self.far_branch(point).query(point, best)

        return best
    def query_subset(self, point, subset):
        """Find the nearest neighbor of point among the given subset of indices."""
        subset_vec = np.zeros(self.n)
        subset_vec[subset] = 1
        return self._query_subset(point, subset_vec, None)

    def _query_subset(self, point, subset, best=None):
        """Recursively implements constrained nearest neighbor search."""
        # Dead end: backtrack up the tree
        if self.node is None:
            return best

        # Indicator vectors for this node and the subtree beneath it
        idx_vec = np.zeros_like(subset)
        child_vec = np.zeros_like(subset)
        idx_vec[self.idx] = child_vec[self.children] = 1

        # If this node is in the subset, try to update best
        if np.dot(idx_vec, subset) != 0:
            # Update if best is unset or the current node is closer
            if best is None or distance(self.node, point) < distance(best[1], point):
                best = (self.idx, self.node)

        near = self.near_branch(point)
        far = self.far_branch(point)

        # Search the near branch if this subtree intersects the queried subset,
        # otherwise move to the away branch
        if np.dot(child_vec, subset) > 0:
            best = near._query_subset(point, subset, best)
        else:
            best = far._query_subset(point, subset, best)

        # Validate best by ensuring a closer point doesn't exist just beyond the
        # partition; if best has yet to be found, also search the far branch
        if best is None or self.orthogonal_dist(point) < distance(best[1], point):
            best = far._query_subset(point, subset, best)

        return best
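
# A minimal usage sketch of the KDTree interface above: build a tree on random
# 2-D points and compare query / query_subset against a brute-force scan.
# The helper name `_kdtree_example` is illustrative only and is not called
# anywhere else in this file.
def _kdtree_example():
    points = np.random.rand(100, 2)
    tree = KDTree(points)
    target = np.random.rand(2)

    # Unconstrained nearest neighbor
    idx, node = tree.query(target)
    assert idx == np.argmin([distance(p, target) for p in points])

    # Nearest neighbor restricted to a subset of row indices
    subset = [0, 5, 10, 42]
    sub_idx, sub_node = tree.query_subset(target, subset)
    assert sub_idx == min(subset, key=lambda i: distance(points[i], target))
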
class PriorityQueue(object):

    def __init__(self):
        """
        Queue implementing highest-priority-in, first-out.

        Note:
            Priority is cost based, therefore smaller values are prioritized
            over larger values.
        """
        self._queue = []
        self._index = 0

    def push(self, item, priority):
        """
        Push an item into the queue.

        Args:
            item (obj): Item to be stored in the queue
            priority (Num): Priority with which the item will be retrieved from the queue
        """
        heapq.heappush(self._queue, (priority, self._index, item))
        self._index += 1

    def pop(self):
        """
        Removes the highest-priority item from the queue.

        Returns:
            obj: item with the highest priority
        """
        return heapq.heappop(self._queue)[-1]

    def merge(self, other):
        """
        Given another queue, consumes each item in it
        and pushes the item and its priority into its own queue.

        Args:
            other (PriorityQueue): Queue to be merged
        """
        while other._queue:
            priority, i, item = heapq.heappop(other._queue)
            self.push(item, priority)

    def top(self):
        """
        Allows a peek at the top item in the queue without removing it.

        Returns:
            obj: the top item if the queue is not empty, otherwise None
        """
        try:
            return self._queue[0][-1]
        except IndexError:
            return None
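
# A small usage sketch of PriorityQueue: lower costs come out first, and
# merge() drains another queue into this one. The helper name
# `_priority_queue_example` is illustrative only.
def _priority_queue_example():
    a, b = PriorityQueue(), PriorityQueue()
    a.push('far edge', 9.0)
    a.push('near edge', 1.0)
    b.push('middle edge', 4.0)

    a.merge(b)                      # b is consumed into a
    assert a.top() == 'near edge'   # peek without removing
    assert a.pop() == 'near edge'   # smallest cost first
    assert a.pop() == 'middle edge'
    assert b.top() is None          # b is now empty
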
def bvka_mst_edges(G, assume_connected=False, pos='coords'):
    """Computes the edge set of a Euclidean minimum spanning tree of G with a
    Boruvka-style algorithm driven by KD-tree nearest neighbor queries."""
    V = set(G.nodes())
    pos = np.vstack(list(nx.get_node_attributes(G, pos).values()))
    kdtree = KDTree(pos)
    subgraphs = nx.utils.UnionFind()
    # One priority queue of candidate edges per component
    queues = defaultdict(PriorityQueue)

    # Seed each component's queue with the edge to its node's nearest neighbor
    for v in V:
        # TODO: restrict this further to connected edges
        vm, _ = kdtree.query_subset(pos[v], list(V - {v}))
        dm = distance(pos[v], pos[vm])
        root = subgraphs[v]
        queues[root].push((v, vm), dm)

    Et = []
    while len(Et) != len(V) - 1:
        Ep = PriorityQueue()
        for C in set(map(subgraphs.__getitem__, subgraphs.parents.values())):
            (v, vm) = queues[C].top()
            component_set = [child for child in list(subgraphs.parents)
                             if subgraphs[child] == C]
            disjoint_nodes = list(V - set(component_set))

            # Discard stale candidates whose endpoint has been absorbed into C,
            # replacing each with that node's nearest neighbor outside C
            while vm in component_set:
                queues[C].pop()
                um, _ = kdtree.query_subset(pos[v], disjoint_nodes)
                dm = distance(pos[v], pos[um])
                queues[C].push((v, um), dm)
                (v, vm) = queues[C].top()

            dm = distance(pos[v], pos[vm])
            Ep.push((v, vm, dm), dm)

        while Ep._queue:
            (um, vm, dm) = Ep.pop()
            component_i, component_j = subgraphs[um], subgraphs[vm]
            if component_i != component_j:
                # Add the edge and merge the component queues
                Et += [(um, vm)]
                subgraphs.union(um, vm)
                if component_i == subgraphs[um]:
                    major, minor = component_i, component_j
                else:
                    minor, major = component_i, component_j
                queues[major].merge(queues[minor])
                del queues[minor]

    return Et
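
# A usage sketch of bvka_mst_edges. The function expects nodes labeled 0..n-1,
# each carrying a positional attribute (default key 'coords'), so that pos[v]
# lines up with the node label; the random point cloud and graph below are
# illustrative only.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    points = rng.rand(50, 2)

    # Euclidean MST of a point cloud, modeled as the complete graph on its nodes
    G = nx.complete_graph(len(points))
    for v, xy in enumerate(points):
        G.nodes[v]['coords'] = xy

    mst_edges = bvka_mst_edges(G)
    print('%d nodes -> %d spanning tree edges' % (len(points), len(mst_edges)))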
@patrafter1999

Hi Blogle,
It's awesome work. Much cleaner than scipy.spatial.KDTree. I haven't tested this code yet, but hopefully it's better than the scipy version, as you showed in the comparison graph. I'm trying to modify your code a little so that I can remove some of the nodes as I wish. Do you have a license on this code?

Cheers,
Sean

@blogle
Author

blogle commented Jul 8, 2020

@patrafter1999
5 years later I am stumbling across this message :(
To you and anyone else coming across this gist, please feel free to use the above code as you wish
