Created
January 8, 2016 20:57
-
-
Save fferri/38a564b3bf304b3cddf6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// Requirements a point type must satisfy to be stored in a `KDTree`.
public protocol KDPoint {
    /// Distance between this point and `other`.
    /// `KDTree.nearestNeighbor` compares these values to rank candidates.
    func distance(other: Self) -> Double

    /// Number of axes of the point; `KDTree` cycles the splitting axis
    /// as `depth % dimension()`.
    func dimension() -> Int

    /// Whether this point orders strictly before `other` along axis `dim`.
    func lessThan(other: Self, dim: Int) -> Bool
}
/// Traversal orders accepted by `KDTree.visit`.
public enum BinaryTreeVisitOrder {
    /// Left subtree, then the node's value, then the right subtree.
    case InOrder
    /// The node's value, then the left subtree, then the right subtree.
    case PreOrder
    /// Left subtree, then the right subtree, then the node's value.
    case PostOrder
}
/// Metadata cached on each `KDTree` node so size and balance queries
/// do not need to walk the subtree.
struct KDTreeNodeInfo {
    /// Number of points in the subtree, the node itself included (0 for a leaf).
    let count: Int

    /// Number of node levels in the subtree (0 for a leaf).
    let height: Int

    /// True when both children are balanced and their heights differ by at most 1.
    let balanced: Bool
}
/// A scalar that can be viewed as a `Double`.
/// `KDPointImpl` uses this to do coordinate arithmetic independently of the
/// concrete number type.
public protocol NumericType {
    /// The value converted to `Double`.
    var doubleValue: Double { get }
}
/// Lets `Int` coordinates be used wherever a `NumericType` is required.
extension Int: NumericType {
    /// The integer converted to `Double`.
    public var doubleValue: Double { return Double(self) }
}
/// Lets `Float` coordinates be used wherever a `NumericType` is required.
extension Float: NumericType {
    /// The single-precision value widened to `Double`.
    public var doubleValue: Double { return Double(self) }
}
/// Lets `Double` coordinates be used wherever a `NumericType` is required.
extension Double: NumericType {
    /// The value itself; no conversion is needed.
    public var doubleValue: Double { return self }
}
/// A ready-made `KDPoint`: an array of numeric coordinates plus an arbitrary
/// payload carried along with the point.
public struct KDPointImpl<T: NumericType, PayloadType> : KDPoint {
    /// Coordinates of the point, one entry per axis.
    public let values: [T]

    /// Caller-supplied data attached to the point; not used in any comparison.
    public let payload: PayloadType

    /// Creates a point from its coordinates and payload.
    public init(values v: [T], payload p: PayloadType) {
        self.values = v
        self.payload = p
    }

    /// Square root of the summed squared coordinate differences.
    ///
    /// Coordinates are paired with `zip`, so when the two points have
    /// different lengths only the common prefix contributes.
    public func distance(p: KDPointImpl<T, PayloadType>) -> Double {
        var sumOfSquares = 0.0
        for (mine, theirs) in zip(values, p.values) {
            sumOfSquares += pow(mine.doubleValue - theirs.doubleValue, 2)
        }
        return sqrt(sumOfSquares)
    }

    /// Number of coordinates stored in `values`.
    public func dimension() -> Int {
        return values.count
    }

    /// Compares the two points' coordinates on axis `dim` as `Double`s.
    public func lessThan(other: KDPointImpl<T, PayloadType>, dim: Int) -> Bool {
        let lhs = values[dim].doubleValue
        let rhs = other.values[dim].doubleValue
        return lhs < rhs
    }
}
/// An immutable k-d tree over points of type `T`.
///
/// `insert` and `rebalance` never modify the receiver; they return a new tree.
/// The splitting axis at a node is `depth % dimension()`.
public enum KDTree<T: KDPoint> {
    /// The empty tree.
    case Leaf

    /// A stored point and its two subtrees.
    ///
    /// `fields` carries the cached `KDTreeNodeInfo` for this subtree. It is
    /// declared as `Any`, and is kept that way so existing pattern matches on
    /// this case continue to compile. Build nodes through `makeNode` so the
    /// cached metadata is always consistent.
    indirect case Node(value: T, fields: Any, left: KDTree<T>, right: KDTree<T>)

    // MARK: - Construction

    /// Computes count/height/balance for a node that has the given children.
    private static func nodeInfo(left l: KDTree<T>, right r: KDTree<T>) -> KDTreeNodeInfo {
        return KDTreeNodeInfo(
            count: 1 + l.count + r.count,
            height: 1 + max(l.height, r.height),
            balanced: l.balanced && r.balanced && abs(l.height - r.height) <= 1)
    }

    /// Returns the splitting axis used at `depth` for a node holding `point`.
    private static func splitAxis(of point: T, depth: Int) -> Int {
        let dimensions = point.dimension()
        // Without this check a zero-dimensional point traps in `%` with no explanation.
        precondition(dimensions > 0, "KDPoint.dimension() must be positive")
        return depth % dimensions
    }

    /// Builds a node for `v` with the given children and freshly computed metadata.
    static func makeNode(value v: T, left l: KDTree<T> = .Leaf, right r: KDTree<T> = .Leaf) -> KDTree<T> {
        return .Node(value: v, fields: nodeInfo(left: l, right: r), left: l, right: r)
    }

    /// Builds a tree from `points` by recursively splitting at the median.
    ///
    /// - Parameter points: the points to store; may be empty.
    /// - Parameter depth: depth of the subtree being built; callers normally
    ///   leave it at its default of 0.
    ///
    /// Points that compare equal to the median on the splitting axis may end
    /// up in either subtree, because the split position is taken from the
    /// sorted order alone.
    public static func fromPoints(points: [T], _ depth: Int = 0) -> KDTree<T> {
        guard let first = points.first else {
            return .Leaf
        }
        if points.count == 1 {
            return makeNode(value: first)
        }
        let axis = splitAxis(of: first, depth: depth)
        let sortedPoints = points.sort { a, b in a.lessThan(b, dim: axis) }
        let median = sortedPoints.count / 2
        return makeNode(
            value: sortedPoints[median],
            left: fromPoints(Array(sortedPoints[0..<median]), depth + 1),
            right: fromPoints(Array(sortedPoints[median+1..<sortedPoints.count]), depth + 1))
    }

    // MARK: - Queries

    /// Returns the stored point closest to `p`, or `nearest` when no stored
    /// point is strictly closer than `minDist`.
    ///
    /// - Parameter p: the query point.
    /// - Parameter nearest: best candidate found so far (used by the recursion).
    /// - Parameter minDist: distance of `nearest` to `p`; only strictly
    ///   smaller distances replace the candidate.
    /// - Parameter depth: depth of the receiver inside the whole tree.
    ///
    /// The earlier implementation followed only the query's side of each
    /// split, which can miss a closer point stored on the other side. Both
    /// subtrees are searched now. `KDPoint` offers no way to measure the
    /// distance from `p` to a splitting plane, so the far side cannot be
    /// pruned and the search visits every node.
    public func nearestNeighbor(p: T, nearest: T? = nil, minDist: Double = Double.infinity, _ depth: Int = 0) -> T? {
        switch self {
        case .Leaf:
            return nearest
        case .Node(let value, _, let left, let right):
            let d = value.distance(p)
            let improved = d < minDist
            let best: T? = improved ? value : nearest
            let bestDist = improved ? d : minDist

            // Search the query's own side first: with the strict `<` above,
            // ties are then resolved in favour of the same point the
            // single-path descent would have returned.
            let axis = KDTree.splitAxis(of: value, depth: depth)
            let queryIsLeft = p.lessThan(value, dim: axis)
            let nearSide = queryIsLeft ? left : right
            let farSide = queryIsLeft ? right : left

            let afterNear = nearSide.nearestNeighbor(p, nearest: best, minDist: bestDist, depth + 1)

            // The recursion returns only the point, so recover its distance.
            // `min` keeps a caller-supplied bound that is tighter than the
            // candidate's real distance.
            var boundAfterNear = bestDist
            if let candidate = afterNear {
                boundAfterNear = min(bestDist, candidate.distance(p))
            }
            return farSide.nearestNeighbor(p, nearest: afterNear, minDist: boundAfterNear, depth + 1)
        }
    }

    /// Returns a new tree that additionally contains `p`.
    ///
    /// Points that are not `lessThan` the node on the splitting axis go to the
    /// right subtree. No rebalancing is performed; see `rebalance`.
    public func insert(p: T, _ depth: Int = 0) -> KDTree<T> {
        switch self {
        case .Leaf:
            return KDTree.makeNode(value: p)
        case .Node(let value, _, let left, let right):
            let axis = KDTree.splitAxis(of: value, depth: depth)
            if p.lessThan(value, dim: axis) {
                return KDTree.makeNode(value: value, left: left.insert(p, depth + 1), right: right)
            } else {
                return KDTree.makeNode(value: value, left: left, right: right.insert(p, depth + 1))
            }
        }
    }

    // MARK: - Metadata

    /// Cached metadata of this subtree.
    ///
    /// `.Node` is a public case whose `fields` slot is `Any`, so a node built
    /// without `makeNode` may carry something else there. In that case the
    /// metadata is recomputed from the children instead of crashing on a
    /// forced cast.
    var fields: KDTreeNodeInfo {
        switch self {
        case .Leaf:
            return KDTreeNodeInfo(count: 0, height: 0, balanced: true)
        case .Node(_, let f, let l, let r):
            if let info = f as? KDTreeNodeInfo {
                return info
            }
            return KDTree.nodeInfo(left: l, right: r)
        }
    }

    /// Number of points stored in the tree.
    public var count: Int { return fields.count }

    /// Number of node levels in the tree (0 when empty).
    public var height: Int { return fields.height }

    /// Whether every node's subtrees differ in height by at most 1.
    public var balanced: Bool { return fields.balanced }

    /// Ratio of `log2(1 + count)` to the actual height; values below 1 mean
    /// the tree is taller than that ideal.
    ///
    /// An empty tree reports 1.0. Previously the 0/0 division yielded NaN.
    public var shape: Double {
        guard height > 0 else {
            return 1.0
        }
        let optimalHeight = log2(Double(1 + count))
        let actualHeight = Double(height)
        return optimalHeight / actualHeight
    }

    /// Rebuilds the tree from its points when it is unbalanced and its
    /// `shape` is below `threshold`; otherwise returns the receiver unchanged.
    public func rebalance(threshold t: Double = 1.0) -> KDTree<T> {
        guard !balanced && shape < t else {
            return self
        }
        var points = [T]()
        visit(.InOrder) { e in points.append(e) }
        return KDTree.fromPoints(points)
    }

    // MARK: - Traversal

    /// Calls `visitor` once for every stored point, in the requested order.
    public func visit(order: BinaryTreeVisitOrder, _ visitor: (T -> Void)) -> Void {
        guard case .Node(let value, _, let left, let right) = self else {
            return
        }
        if order == .PreOrder { visitor(value) }
        left.visit(order, visitor)
        if order == .InOrder { visitor(value) }
        right.visit(order, visitor)
        if order == .PostOrder { visitor(value) }
    }

    /// Prints the tree to standard output, one node per line, indented four
    /// spaces per level, with each node's cached metadata.
    public func printFormatted(depth: Int = 0) {
        guard case .Node(let value, _, let left, let right) = self else {
            return
        }
        let indent = String(count: depth*4, repeatedValue: Character(" "))
        print("\(indent)\(value) (count=\(count), height=\(height), balanced=\(balanced), shape=\(shape))")
        left.printFormatted(depth + 1)
        right.printFormatted(depth + 1)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@vccabral Yep, I ended up taking implementation ideas from https://www.objc.io/books/functional-swift/ and building a full pod.
I recently even gave a talk about the results: https://youtu.be/CwcEjxRtn18
PR for changes or new features very welcome!