Created November 28, 2018 01:24
-
-
Save GoBigorGoHome/3024d55746e546ed2a329552d06aa3c9 to your computer and use it in GitHub Desktop.
KD Tree 模板 (KD-tree template)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Static k-d tree over DIM-dimensional integer points.
///
/// Nodes are bump-allocated from the fixed in-object pool `t[N]`; `e` points
/// at the next free slot, so at most N nodes can ever be created. Building
/// more than N points overruns the pool (unchecked, undefined behaviour).
/// The pool makes every instance N * sizeof(Node) bytes, so instances are
/// normally declared at namespace scope or `static`. Copying an instance
/// copies `root`/`e`, which keep pointing into the source object's pool.
///
/// Typical use:
///     KDtree<2> kd;
///     kd.root = kd.build(pts, 0, n - 1, 0);  // pts: scratch array of Nodes
///     int c = kd.range_query(box);           // number of points inside box
template <int DIM, int N = 100000>
class KDtree {
    static_assert(DIM > 0, "KDtree needs at least one dimension");

public:
    using coord = std::array<int, DIM>;

    /// Squared Euclidean distance between a and b.
    /// Each per-axis difference is widened to long long before squaring so
    /// that large coordinate gaps cannot overflow int.
    static constexpr long long dist2(const coord &a, const coord &b) {
        long long ans = 0;
        for (int i = 0; i < DIM; i++) {
            const long long d = static_cast<long long>(a[i]) - b[i];
            ans += d * d;
        }
        return ans;
    }

    /// Strict "less than" on axis `dim`; usable as a comparator for coords.
    template <int dim>
    static constexpr bool cmp(const coord &a, const coord &b) {
        return a[dim] < b[dim];
    }

    /// Axis-aligned box [l, r], inclusive on every axis.
    /// NOTE(review): the original comment here read "不起名作 range" (roughly
    /// "don't name it range") - presumably a warning about clashing with some
    /// other `range` identifier; the name is kept so existing callers compile.
    struct range {
        coord l, r;

        /// True if point x lies inside this box (bounds inclusive).
        bool cover(const coord &x) const {
            for (int i = 0; i < DIM; i++) {
                if (x[i] > r[i] || x[i] < l[i]) return false;
            }
            return true;
        }

        /// True if box `ran` lies entirely inside this box.
        bool cover(const range &ran) const {
            for (int i = 0; i < DIM; i++) {
                if (ran.l[i] < l[i] || ran.r[i] > r[i]) return false;
            }
            return true;  // only after every axis has been checked
        }

        /// True if this box and `ran` share at least one point.
        bool overlaps(const range &ran) const {
            for (int i = 0; i < DIM; i++) {
                if (ran.r[i] < l[i] || ran.l[i] > r[i]) return false;
            }
            return true;
        }

        /// Shrink this box to its intersection with `ran`; the result may be
        /// empty (l[i] > r[i] on some axis).
        void intersect(const range &ran) {
            for (int i = 0; i < DIM; i++) {
                l[i] = std::max(l[i], ran.l[i]);
                r[i] = std::min(r[i], ran.r[i]);
            }
        }

        /// Grow this box to the smallest box containing both it and `ran`.
        void unite(const range &ran) {
            for (int i = 0; i < DIM; i++) {
                l[i] = std::min(l[i], ran.l[i]);
                r[i] = std::max(r[i], ran.r[i]);
            }
        }
    };

    struct Node {
        using LL = long long;
        coord x{};                        // the point stored at this node
        range ran{};                      // bounding box of this subtree
        int dim = 0;                      // split axis used at this node
        int size = 0;                     // number of points in this subtree
        Node *l = nullptr, *r = nullptr;  // left / right children

        /// Print the point's coordinates space-separated on one line.
        void out() const {
            for (int i = 0; i < DIM; i++)
                std::printf("%d%c", x[i], i == DIM - 1 ? '\n' : ' ');
        }

        Node() = default;

        /// Leaf holding point x, split on axis dim.
        Node(const coord &x, int dim) : x(x), ran{x, x}, dim(dim), size(1) {}

        /// Recompute size and bounding box from this point and the children.
        void upd() {
            ran.l = ran.r = x;
            size = 1;
            if (l) {
                size += l->size;
                ran.unite(l->ran);  // the box must grow to cover the child
            }
            if (r) {
                size += r->size;
                ran.unite(r->ran);
            }
        }
    };

    Node *root = nullptr;  // tree root; nullptr until build() is assigned to it
    Node t[N];             // node pool
    Node *e = t;           // e: end - next free slot in t

    /// Build a balanced subtree from p[l..r] (both ends inclusive), splitting
    /// on axis `dim` and cycling through the axes by depth. Returns the
    /// subtree root, or nullptr when l > r.
    /// p is reordered in place and used only as scratch: the points are copied
    /// into the pool, so p must not alias t. Consumes r - l + 1 pool slots.
    Node *build(Node p[], int l, int r, int dim) {
        if (l > r) return nullptr;
        const int m = l + (r - l) / 2;
        // Put the median (on axis dim) at p[m]; no element before it is
        // greater and none after it is smaller.
        std::nth_element(p + l, p + m, p + r + 1,
                         [dim](const Node &a, const Node &b) { return a.x[dim] < b.x[dim]; });
        Node *ptr = e++;
        *ptr = Node(p[m].x, dim);
        const int next = (dim + 1) % DIM;
        ptr->l = build(p, l, m - 1, next);
        ptr->r = build(p, m + 1, r, next);
        ptr->upd();  // fold the children's sizes and boxes into this node
        return ptr;
    }

    /// NOTE(review): unimplemented stub kept from the original template; it
    /// ignores its argument and does nothing.
    void ins(Node *) {}

    /// Number of stored points lying inside box q (bounds inclusive).
    int range_query(const range &q) const { return count_in(root, q); }

    /// Count points of subtree p inside q, pruning with subtree bounding boxes:
    /// a disjoint box contributes 0, a fully covered box contributes its size.
    static int count_in(const Node *p, const range &q) {
        if (p == nullptr || !q.overlaps(p->ran)) return 0;
        if (q.cover(p->ran)) return p->size;
        return (q.cover(p->x) ? 1 : 0) + count_in(p->l, q) + count_in(p->r, q);
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.