Skip to content

Instantly share code, notes, and snippets.

@spaghetti-source
Created November 16, 2014 11:17
Show Gist options
  • Save spaghetti-source/1445f98857aaedb18aff to your computer and use it in GitHub Desktop.
Save spaghetti-source/1445f98857aaedb18aff to your computer and use it in GitHub Desktop.
Vantage Point Tree
//
// Vantage Point Tree (vp tree)
//
// Each node has two children, left and right.
// The left subtree holds points no farther from the node's vantage point
// than the node's threshold; the right subtree holds points no closer.
//
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <complex>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;
// Stopwatch: returns the processor time, in seconds, that clock() reports
// as having elapsed since the previous call. The stored reading has static
// storage duration and therefore starts at zero, so the first call measures
// from clock()'s own starting point.
double tick() {
  static clock_t previous;
  const clock_t now = clock();
  const clock_t elapsed = now - previous;
  previous = now;
  return static_cast<double>(elapsed) / CLOCKS_PER_SEC;
}
// 2-D point stored as a complex number (x = real part, y = imaginary part).
using point = complex<double>;
// Short-hands for pair members and whole-container iterator ranges.
#define fst first
#define snd second
#define all(c) ((c).begin()), ((c).end())
struct vantage_point_tree {
vector<point> ps;
struct node {
int id; // id
double th; // threshold
node *l, *r; // l contains points closer than th
} *root;
vantage_point_tree(vector<point> ps_) : ps(ps_) {
root = build(0, ps.size());
}
node *build(int l, int r) {
if (l >= r) return 0;
if (l+1 == r) return new node({l});
swap(ps[l], ps[l + rand() % (r - l)]);
int m = (l + r) / 2;
nth_element(ps.begin()+l+1, ps.begin()+m, ps.begin()+r,
[&](point p, point q) { return norm(p - ps[l]) < norm(q - ps[l]); });
return new node({l, abs(ps[l]-ps[m]), build(l+1, m), build(m, r)});
}
void closest(node *t, point p, pair<double, node*> &ub) {
if (!t) return;
double d = abs(p - ps[t->id]);
if (d < ub.fst) ub = {d, t};
if (!t->l && !t->r) return;
if (d < t->th) { // the ball contains p
closest(t->l, p, ub);
if (t->th - d < ub.fst) closest(t->r, p, ub);
} else { // the ball excludes p
closest(t->r, p, ub);
if (d - t->th < ub.fst) closest(t->l, p, ub);
}
}
point closest(point p) {
pair<double, node*> ub(1.0/0.0, 0);
closest(root, p, ub);
return ps[ub.snd->id];
}
};
int main() {
srand( 0xdeadbeef );
int n = 100000;
vector<point> ps;
for (int i = 0; i < n; ++i)
ps.push_back(point(rand()%n, rand()%n));
tick();
vantage_point_tree T(ps);
cout << "construct " << n << " points: " << tick() << "[s]" << endl;
// search
tick();
for (int i = 0; i < n; ++i)
T.closest(ps[i]);
cout << "search " << n << " points: " << tick() << "[s]" << endl;
// verify
for (int i = 0; i < 100; ++i) {
point p(rand(), rand());
point Tp = T.closest(p);
point Tq = ps[0];
for (auto q: ps)
if (norm(p - Tq) > norm(p - q)) Tq = q;
if (abs(norm(Tp - p) - norm(Tq - p)) > 1e-8) {
cout << norm(Tp - p) << endl;
cout << norm(Tq - p) << endl;
cout << "ERROR" << endl;
return 0;
}
}
cout << "verification passed" << endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment