Skip to content

Instantly share code, notes, and snippets.

@ngshaohui
Last active April 5, 2020 02:04
Show Gist options
  • Save ngshaohui/988c6538774d38cd0463cdfb45ab7aaa to your computer and use it in GitHub Desktop.
Save ngshaohui/988c6538774d38cd0463cdfb45ab7aaa to your computer and use it in GitHub Desktop.
Balanced k-d tree (KDTree) implementation in TypeScript
import { KDTree, DistanceFunc } from './KDTree'
import { expect } from 'chai'
import 'mocha'
/** Two-dimensional test point; `x` and `y` are the coordinates used by the tree in these tests. */
interface Point {
  /** horizontal coordinate */
  x: number
  /** vertical coordinate */
  y: number
}
// Fixture: six points with pairwise-distinct x and y values, written as
// [x, y] tuples and expanded into Point objects (key order: x, then y).
const pointList: Point[] = Array.of<[number, number]>(
  [7, 2],
  [5, 4],
  [9, 6],
  [4, 7],
  [8, 1],
  [2, 3],
).map(([x, y]) => ({ x, y }))
/**
 * Euclidean distance between two points in the plane.
 *
 * @param point1 - first point
 * @param point2 - second point
 * @returns sqrt(dx^2 + dy^2)
 */
function distanceFormula(point1: Point, point2: Point): number {
  const dx = point1.x - point2.x
  const dy = point1.y - point2.y
  return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2))
}
// metric passed to KDTree#getNearest by every test in this file
const compFunc: DistanceFunc<Point> = distanceFormula
describe('invalid KDTree', () => {
  it('should not allow tree to be created', () => {
    // Object.keys({}) is empty, so the tree would have no axis to split on
    const keylessPoints = [{}, {}, {}]
    expect(() => new KDTree(keylessPoints)).to.throw(
      Error,
      'Object should not have depth of 0',
    )
  })
})
describe('valid 2 dimensional KDTree', () => {
  it('should allow empty list', () => {
    const emptyTree: KDTree<Point> = new KDTree([])
    expect(emptyTree.getNearest({ x: 5, y: 5 }, compFunc)).to.equal(null)
  })

  describe('for KDTree built using pointList', () => {
    const tree: KDTree<Point> = new KDTree(pointList, ['x', 'y'])

    it('should find nearest neighbour', () => {
      // [query, expected nearest point from pointList]
      const cases: Array<[Point, Point]> = [
        [{ x: 5, y: 5 }, { x: 5, y: 4 }],
        [{ x: 2, y: 7 }, { x: 4, y: 7 }],
        [{ x: 10, y: 10 }, { x: 9, y: 6 }],
        [{ x: 8, y: 1 }, { x: 8, y: 1 }],
      ]
      for (const [query, expected] of cases) {
        expect(tree.getNearest(query, compFunc)).to.eql(expected)
      }
    })

    it('should find nn for queryPoint directly on point', () => {
      const onPoint: Point = { x: 8, y: 1 }
      expect(tree.getNearest(onPoint, compFunc)).to.eql({ x: 8, y: 1 })
    })
  })

  it('should still work without specifying list of keys', () => {
    const inferredTree: KDTree<Point> = new KDTree(pointList)
    expect(new Set(inferredTree.getKeys())).to.eql(new Set(['x', 'y']))
  })
})
/**
 * Binary tree node holding one value plus optional left and right children.
 *
 * Children are `null` when absent; both start out `null`.
 */
class TreeNode<T> {
  // typed as nullable because leaves (and freshly created nodes) have no children
  private left: TreeNode<T> | null = null
  private right: TreeNode<T> | null = null
  private data: T

  /** @param data - value stored in this node */
  constructor(data: T) {
    this.data = data
  }

  /** Replaces the left child (pass `null` to detach it). */
  setLeft(left: TreeNode<T> | null): void {
    this.left = left
  }

  /** Replaces the right child (pass `null` to detach it). */
  setRight(right: TreeNode<T> | null): void {
    this.right = right
  }

  /** Replaces the stored value. */
  setData(data: T): void {
    this.data = data
  }

  /** @returns the left child, or `null` if there is none */
  getLeft(): TreeNode<T> | null {
    return this.left
  }

  /** @returns the right child, or `null` if there is none */
  getRight(): TreeNode<T> | null {
    return this.right
  }

  /** @returns the stored value */
  getData(): T {
    return this.data
  }
}
/**
 * Structural equality check.
 *
 * Two non-null objects are equal when both or neither are arrays, they have
 * the same number of keys enumerated by `for...in` (own and inherited
 * enumerable), and every such key exists on the other side with a deeply
 * equal value. Any other pair of values is compared with `===`.
 *
 * @param a - first value
 * @param b - second value
 * @returns `true` if `a` and `b` are deeply equal
 */
function deepEqual<T>(a: T, b: T): boolean {
  if (typeof a === 'object' && a !== null && typeof b === 'object' && b !== null) {
    // an array must not match a plain object that merely has index-like keys
    if (Array.isArray(a) !== Array.isArray(b)) {
      return false
    }
    let countA = 0
    let countB = 0
    for (const _key in a) countA++
    for (const _key in b) countB++
    if (countA !== countB) {
      return false
    }
    for (const key in a) {
      if (!(key in b) || !deepEqual(a[key], b[key])) {
        return false
      }
    }
    // symmetric pass: `in` also sees non-enumerable/inherited keys, so the
    // count check alone does not guarantee both key sets match
    for (const key in b) {
      if (!(key in a) || !deepEqual(b[key], a[key])) {
        return false
      }
    }
    return true
  }
  return a === b
}
/** Best candidate seen so far while searching for a nearest neighbour. */
interface Champion<T> {
  /** the candidate point */
  data: T
  /** distance from the query point to `data`, as computed by the caller's distance function */
  distance: number
}
/**
 * Distance between two points of type `T`, returned as a number.
 *
 * Used by KDTree both between stored points and the query, and between the
 * query and a copy of it with one coordinate moved onto a splitting plane.
 */
interface DistanceFunc<T> {
  (from: T, to: T): number
}
/**
 * Balanced k-d tree supporting nearest-neighbour queries.
 *
 * Points are objects whose coordinates are the properties named in `keys`,
 * compared with `<` / `>`. The splitting key cycles through `keys` by depth,
 * and each node stores the median point (along that key) of its subtree.
 */
class KDTree<T extends {}> {
  private root: TreeNode<T> | null = null
  private keys: string[] = []

  /**
   * Builds the tree from `ls`; the input array is not mutated.
   *
   * @param ls - points to index
   * @param useKeys - coordinate property names in splitting order; when
   *   omitted or empty, `Object.keys(ls[0])` is used
   * @throws Error ('Object should not have depth of 0') when `ls` is
   *   non-empty and no coordinate keys are available
   */
  constructor(ls: T[], useKeys?: string[]) {
    if (ls.length === 0) {
      // Nothing to index. Still record explicitly requested keys so that
      // getKeys() returns a string[] instead of undefined.
      this.keys = useKeys?.length ? [...useKeys] : []
      return
    }
    // use specified keys if provided, all keys of the first object otherwise
    this.setKeys(useKeys?.length ? useKeys : Object.keys(ls[0]))
    // one copy of the input per key, each sorted ascending by that key
    const sorted = this.keys.map((key) =>
      ls.slice(0).sort((a, b) => {
        if (a[key] > b[key]) {
          return 1
        }
        if (a[key] < b[key]) {
          return -1
        }
        return 0
      }),
    )
    this.setRoot(this.buildTree(sorted, this.keys, 0))
  }

  /**
   * Recursively builds the subtree for one set of points.
   *
   * @param ls - the same set of points once per key, the copy at index i
   *   sorted by `keys[i]`
   * @param keys - coordinate keys in splitting order
   * @param depth - depth of the node being built; selects the splitting key
   * @returns the subtree root, or null when there are no points
   */
  private buildTree(ls: T[][], keys: string[], depth: number): TreeNode<T> | null {
    const count = ls[0].length
    if (count === 0) {
      return null
    }
    if (count === 1) {
      return new TreeNode(ls[0][0])
    }
    const axis = depth % keys.length
    const key = keys[axis]
    // the copy already sorted by the splitting key yields the median directly
    const currentList = ls[axis]
    const currentPoint = currentList[Math.floor(currentList.length / 2)]
    const currentNode = new TreeNode(currentPoint)
    // Partition every presorted copy with the same predicate so each child
    // again gets one still-sorted list per key. Points deepEqual to the median
    // (duplicates of it) go to neither side, so duplicates are stored once.
    const left = ls.map((xs) => xs.filter((point) => point[key] < currentPoint[key]))
    const right = ls.map((xs) =>
      xs.filter(
        (point) => point[key] >= currentPoint[key] && !deepEqual(point, currentPoint),
      ),
    )
    currentNode.setLeft(this.buildTree(left, keys, depth + 1))
    currentNode.setRight(this.buildTree(right, keys, depth + 1))
    return currentNode
  }

  private setRoot(root: TreeNode<T> | null): void {
    this.root = root
  }

  /**
   * @param keys - coordinate keys in splitting order (copied)
   * @throws Error when `keys` is empty
   */
  private setKeys(keys: string[]): void {
    if (keys.length === 0) {
      throw new Error('Object should not have depth of 0')
    }
    this.keys = [...keys]
  }

  /** @returns a copy of the coordinate keys in splitting order */
  getKeys(): string[] {
    return [...this.keys]
  }

  /**
   * Finds the stored point closest to `queryPoint`.
   *
   * @param queryPoint - point to search around
   * @param getDistance - distance between two points. It is also called with a
   *   copy of `queryPoint` whose splitting coordinate is moved onto a node's
   *   plane, to decide whether the far subtree may be skipped (presumably
   *   requires a Euclidean-like metric for correct pruning)
   * @returns the nearest stored point, or null when the tree is empty
   */
  getNearest(queryPoint: T, getDistance: DistanceFunc<T>): T | null {
    if (!this.root) {
      return null
    }
    return this.searchSubtree(this.root, null, queryPoint, this.keys, getDistance, 0).data
  }

  /**
   * Nearest-neighbour search within the subtree rooted at `curNode`.
   *
   * @param curNode - subtree root (never null; empty children are skipped)
   * @param champion - best candidate found so far, or null before the first
   * @param depth - depth of `curNode`; selects the splitting key
   * @returns the best candidate after considering this subtree
   */
  private searchSubtree(
    curNode: TreeNode<T>,
    champion: Champion<T> | null,
    queryPoint: T,
    keys: string[],
    getDistance: DistanceFunc<T>,
    depth: number,
  ): Champion<T> {
    const curData = curNode.getData()
    const curDistance = getDistance(queryPoint, curData)
    // keep the previous champion only if it is strictly closer
    let best: Champion<T> =
      champion !== null && champion.distance < curDistance
        ? champion
        : { distance: curDistance, data: curData }
    const key = keys[depth % keys.length]
    // query point projected onto this node's splitting plane
    const borderPoint = {
      ...queryPoint,
      [key]: curData[key],
    }
    const borderDistance = getDistance(borderPoint, queryPoint)
    const goLeft = queryPoint[key] < curData[key]
    const near = goLeft ? curNode.getLeft() : curNode.getRight()
    const far = goLeft ? curNode.getRight() : curNode.getLeft()
    if (near) {
      best = this.searchSubtree(near, best, queryPoint, keys, getDistance, depth + 1)
    }
    // the far side can only contain a closer point if the current best
    // hypersphere crosses the splitting plane
    if (far && best.distance > borderDistance) {
      best = this.searchSubtree(far, best, queryPoint, keys, getDistance, depth + 1)
    }
    return best
  }
}
export { KDTree, DistanceFunc }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment