Skip to content

Instantly share code, notes, and snippets.

@firejox
Created October 26, 2014 15:11
Show Gist options
  • Save firejox/17d3ab55fd3a3adefd08 to your computer and use it in GitHub Desktop.
Save firejox/17d3ab55fd3a3adefd08 to your computer and use it in GitHub Desktop.
b326 (努力理解紅黑樹中)
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Capacity of the per-vertex global arrays; the segment tree array uses
 * MAXN << 1 slots. */
#define MAXN 32780
/* Red-black tree nodes are always handled through this pointer typedef. */
typedef struct _Node *Node;
/* Values stored in the 1-bit color field of struct _Node. */
#define BLACK 1
#define RED 0
/* One red-black tree node. Absent children point at the owning tree's NIL
 * sentinel (set by insert); pa is NULL for the root. */
struct _Node {
/* RED or BLACK */
unsigned int color : 1;
/* key; duplicates are allowed (insert sends equal keys to the left) */
int num;
/* left child, right child, parent */
Node l, r, pa;
};
Node newNode() {
Node tmp = (Node)malloc(sizeof(struct _Node));
tmp->color = RED;
tmp->num = 0;
tmp->l = NULL;
tmp->r = NULL;
tmp->pa = NULL;
return tmp;
}
/* Parent of no's parent; NULL when no has no parent (and, naturally, when
 * no's parent is the root). */
Node grandpa(Node no) {
    Node parent = no->pa;
    return parent != NULL ? parent->pa : NULL;
}
/* The other child of no's grandparent (may be the NIL sentinel), or NULL
 * when no has no grandparent. */
Node uncle(Node no) {
    Node g = grandpa(no);

    if (!g)
        return NULL;
    return (no->pa == g->r) ? g->l : g->r;
}
/* The other child of no's parent (may be the NIL sentinel), or NULL when no
 * has no parent. */
Node sibling(Node no) {
    Node parent = no->pa;

    if (!parent)
        return NULL;
    return (parent->l == no) ? parent->r : parent->l;
}
/* A red-black tree. root is NULL while the tree is empty; NIL is the tree's
 * own black sentinel (allocated by initRB) used in place of absent children. */
typedef struct {
Node root;
Node NIL;
} RB;
/* Reset T to an empty tree with a freshly allocated black NIL sentinel.
 * Nodes belonging to whatever T previously held are not freed. */
void initRB(RB *T) {
    Node sentinel = newNode();

    sentinel->color = BLACK;
    T->NIL = sentinel;
    T->root = NULL;
}
/*
 * Right rotation around p's parent. p must be the left child of its parent
 * fa: p moves up into fa's slot (under fa's parent, or as T->root) and fa
 * becomes p's right child, adopting p's former right subtree as its left.
 */
void rotate_r(RB *T, Node p) {
    Node upper = grandpa(p);
    Node fa = p->pa;
    Node inner = p->r;

    /* p's right subtree moves across to fa. */
    fa->l = inner;
    if (inner != T->NIL) {
        inner->pa = fa;
    }

    /* fa now hangs below p. */
    p->r = fa;
    fa->pa = p;

    /* p inherits fa's former position. */
    if (T->root == fa) {
        T->root = p;
    }
    p->pa = upper;
    if (upper == NULL) {
        return;
    }
    if (upper->l == fa) {
        upper->l = p;
    } else {
        upper->r = p;
    }
}
/*
 * Left rotation around p's parent. p must be the right child of its parent
 * fa: p moves up into fa's slot (under fa's parent, or as T->root) and fa
 * becomes p's left child, adopting p's former left subtree as its right.
 * If p has no parent it is simply recorded as the root.
 */
void rotate_l(RB *T, Node p) {
    Node fa, gp, inner;

    if (p->pa == NULL) {
        T->root = p;
        return;
    }
    fa = p->pa;
    gp = fa->pa;
    inner = p->l;

    /* p's left subtree becomes fa's right subtree. fa->r must be overwritten
     * here, otherwise it keeps pointing at p and the rotation leaves a
     * parent/child cycle. */
    fa->r = inner;
    if (inner != T->NIL)
        inner->pa = fa;

    /* fa now hangs below p. */
    p->l = fa;
    fa->pa = p;

    /* p inherits fa's former position. */
    if (T->root == fa)
        T->root = p;
    p->pa = gp;
    if (gp != NULL) {
        if (gp->l == fa)
            gp->l = p;
        else
            gp->r = p;
    }
}
/* Leftmost (smallest-key) node of the subtree rooted at p; p must be a real
 * node, not T->NIL. */
Node getMinch(RB *T, Node p) {
    Node cur;

    for (cur = p; cur->l != T->NIL; cur = cur->l)
        ;
    return cur;
}
/*
 * Restore red-black properties after the red node p has been linked into T.
 * While p and its parent are both red:
 *   - red uncle: recolour parent and uncle black and grandparent red, then
 *     continue from the grandparent;
 *   - black uncle (possibly NIL): one or two rotations around the
 *     grandparent fix the violation and the loop ends.
 * A black parent needs no work. Finally the root is forced black.
 */
void insert_case(RB *T, Node p) {
    while (p->pa != NULL && p->pa->color == RED) {
        Node fa = p->pa;
        /* fa is red, so it is not the root (the root is kept black); gp and
         * uc therefore exist, although uc may be the NIL sentinel. */
        Node gp = grandpa(p);
        Node uc = uncle(p);

        if (uc->color == RED) {
            fa->color = uc->color = BLACK;
            gp->color = RED;
            p = gp;
            continue;
        }

        if (fa == gp->l) {
            if (p == fa->r) {
                /* Left-right: lift p twice so it replaces gp. */
                rotate_l(T, p);
                rotate_r(T, p);
                p->color = BLACK;
                p->l->color = p->r->color = RED;
            } else {
                /* Left-left: fa is gp's left child, so lift it with a right
                 * rotation. */
                fa->color = BLACK;
                gp->color = RED;
                rotate_r(T, fa);
            }
        } else {
            if (p == fa->l) {
                /* Right-left: mirror image of left-right. */
                rotate_r(T, p);
                rotate_l(T, p);
                p->color = BLACK;
                p->l->color = p->r->color = RED;
            } else {
                /* Right-right: fa is gp's right child, so it must come up
                 * via a LEFT rotation. */
                fa->color = BLACK;
                gp->color = RED;
                rotate_l(T, fa);
            }
        }
        break;
    }

    if (p->pa == NULL)
        T->root = p;
    T->root->color = BLACK;
}
/*
 * Insert key x into T and rebalance. Duplicates are allowed: keys equal to
 * an existing node's key are placed in its left subtree.
 */
void insert(RB *T, int x) {
    Node parent, cur, fresh;

    if (T->root == NULL) {
        T->root = newNode();
        T->root->color = BLACK;
        T->root->l = T->root->r = T->NIL;
        T->root->num = x;
        return;
    }

    /* Walk down to the NIL slot where x belongs, remembering its parent. */
    parent = NULL;
    cur = T->root;
    while (cur != T->NIL) {
        parent = cur;
        cur = (cur->num >= x) ? cur->l : cur->r;
    }

    fresh = newNode();
    fresh->num = x;
    fresh->l = fresh->r = T->NIL;
    fresh->pa = parent;
    if (parent->num >= x)
        parent->l = fresh;
    else
        parent->r = fresh;

    /* Exactly one node is attached per call; rebalance once, after the
     * link is in place, and stop. */
    insert_case(T, fresh);
}
/*
 * Repair red-black properties after delete1ch unlinked a black node and p
 * took its place, leaving paths through p one black node short. p may be
 * the NIL sentinel; delete1ch sets NIL->pa beforehand so that p->pa and
 * sibling(p) still work. Follows the usual sibling-based case analysis and
 * recurses on the parent when the deficit has to be pushed upward.
 */
void delete_case(RB *T, Node p) {
Node sib = sibling(p);
/* p is the root: nothing above to rebalance, just make it black. */
if (p->pa == NULL) {
p->color = BLACK;
return;
}
/* Red sibling: rotate it above the parent (recolouring both) so that p
 * ends up with a black sibling, then re-read the sibling. */
if (sib->color == RED) {
p->pa->color = RED;
sib->color = BLACK;
if (p == p->pa->l)
rotate_l(T, sib);
else
rotate_r(T, sib);
sib = sibling(p);
}
/* Black sibling with two black children: make the sibling red. A black
 * parent now carries the deficit and is fixed recursively; a red parent
 * absorbs it by turning black. */
if (sib->color == BLACK &&
sib->l->color == BLACK &&
sib->r->color == BLACK) {
sib->color = RED;
if (p->pa->color == BLACK)
delete_case(T, p->pa);
else
p->pa->color = BLACK;
} else {
/* Sibling's child nearer to p is red but the farther one is black:
 * rotate the near child above the sibling so that the far child
 * position holds a red node. */
if (sib->color == BLACK) {
if (p == p->pa->l &&
sib->l->color == RED &&
sib->r->color == BLACK) {
sib->color = RED;
sib->l->color = BLACK;
rotate_r(T, sib->l);
sib = sibling(p);
} else if (p == p->pa->r &&
sib->l->color == BLACK &&
sib->r->color == RED) {
sib->color = RED;
sib->r->color = BLACK;
rotate_l(T, sib->r);
sib = sibling(p);
}
}
/* Far child is red: the sibling takes the parent's colour, parent and
 * far child turn black, and the sibling is rotated above the parent,
 * which restores the missing black on p's side. */
sib->color = p->pa->color;
p->pa->color = BLACK;
if (p == p->pa->l) {
sib->r->color = BLACK;
rotate_l(T, sib);
} else {
sib->l->color = BLACK;
rotate_r(T, sib);
}
}
}
/*
 * Unlink and free p, which must have at most one non-NIL child (deletech
 * guarantees this), then repair red-black properties. When p was the last
 * node, T->root becomes NULL.
 */
void delete1ch(RB *T, Node p) {
    Node child = (p->l == T->NIL) ? p->r : p->l;

    if (p->pa == NULL) {
        /* p is the root: its only child (if any) takes over as a black
         * root. p is freed in both cases, including when it was the only
         * node in the tree. */
        if (child == T->NIL) {
            T->root = NULL;
        } else {
            child->pa = NULL;
            child->color = BLACK;
            T->root = child;
        }
        free(p);
        return;
    }

    /* Splice child into p's slot. child may be NIL; setting NIL->pa lets
     * delete_case find its parent and sibling. */
    if (p->pa->l == p)
        p->pa->l = child;
    else
        p->pa->r = child;
    child->pa = p->pa;

    /* Removing a black node leaves this path one black short: a red child
     * absorbs that by turning black, otherwise rebalance. */
    if (p->color == BLACK) {
        if (child->color == RED)
            child->color = BLACK;
        else
            delete_case(T, child);
    }
    free(p);
}
/*
 * Remove one node whose key equals data from the subtree rooted at p
 * (callers pass T->root, which is NULL for an empty tree).
 * Returns 1 if a node was removed, 0 if data was not found or the tree is
 * empty.
 */
int deletech(RB *T, Node p, int data) {
    while (p != NULL && p != T->NIL) {
        if (p->num < data) {
            p = p->r;
        } else if (p->num > data) {
            p = p->l;
        } else {
            Node victim = p;

            if (p->r != T->NIL) {
                /* p has a right subtree: take over its in-order successor's
                 * key and remove the successor instead (it has no left
                 * child, so delete1ch's precondition holds). */
                victim = getMinch(T, p->r);
                p->num = victim->num;
            }
            delete1ch(T, victim);
            return 1;
        }
    }
    return 0;
}
/* Segment tree over DFS positions [0, N). Node id covers the half-open range
 * [l, r) and has children id*2 and id*2+1 (built by initST). Each node keeps
 * a red-black tree T of the vertex numbers that update() placed there. */
struct {
int l;
int r;
RB T;
} st[MAXN << 1];
/* Adjacency-list head per vertex: an edge-slot index, or -1 for an empty
 * list. */
int fi[MAXN];
/* e[i][0], e[i][1]: endpoints of input edge i. */
int e[MAXN][2];
/* next[k]: the following slot in the same adjacency list (-1 ends it).
 * Slot k belongs to edge k/2 and leads to the neighbour e[k/2][k%2]. */
int next[MAXN << 1];
/* DFS preorder counter used by traversal. */
int cc;
/* id[u]: DFS preorder position of u; sz[u]: size of u's DFS subtree, so the
 * subtree occupies positions [id[u], id[u] + sz[u]). */
int id[MAXN], sz[MAXN];
/* w[u]: nonzero means "not yet visited" during traversal (main fills it with
 * -1 bytes and traversal clears it); afterwards 'M' commands toggle it and
 * update reads it to choose between insert and delete. */
char w[MAXN];
/* stid[pos]: segment-tree node id of the leaf covering position pos. */
int stid[MAXN];
void initST(int l, int r, int id) {
int mid = (l+r)/2;
st[id].l = l;
st[id].r = r;
initRB(&st[id].T);
if ((r - l) == 1) {
stid[l] = id;
return;
}
initST(l, mid, id*2);
initST(mid, r, id*2+1);
}
void update(int id, int l, int r, int v) {
int mid = (st[id].l + st[id].r) >> 1;
if (st[id].l == l && st[id].r == r) {
if (w[v])
insert(&st[id].T, v);
else
deletech(&st[id].T, st[id].T.root, v);
return;
}
else if (mid >= r)
update(id*2, l, r, v);
else if (mid <= l)
update(id*2 +1, l, r, v);
else {
update(id*2, l, mid, v);
update(id*2 +1, mid, r, v);
}
}
int query(int id) {
Node p = NULL;
while (id >= 1) {
if (st[id].T.root != NULL) {
for (p = st[id].T.root; p->r != st[id].T.NIL; p = p->r);
return p->num;
}
id >>= 1;
}
return -1;
}
int traversal(int n) {
int i;
w[n] = 0;
id[n] = cc++;
sz[n] = 1;
for (i = fi[n]; i != -1; i = next[i])
if (w[e[i/2][i%2]])
sz[n] += traversal(e[i/2][i%2]);
return sz[n];
}
int main(void) {
int N;
int i, Q, vt;
char ch;
while (scanf("%d", &N) != EOF) {
memset(fi, -1, N*sizeof(int));
memset(w, -1, N);
initST(0, N, 1);
for (i = 0; i < N-1; i++) {
scanf("%d%d", &e[i][0], &e[i][1]);
next[i*2] = fi[e[i][1]];
fi[e[i][1]] = i*2;
next[i*2 +1] = fi[e[i][0]];
fi[e[i][0]] = i*2 +1;
}
cc = 0;
traversal(0);
scanf("%d\n", &Q);
for (i = 0; i < Q; i++) {
ch = getchar();
scanf("%d\n", &vt);
if (ch == 'M') {
w[vt] = !w[vt];
update(1, id[vt], id[vt] + sz[vt], vt);
} else {
vt = query(stid[id[vt]]);
printf("%d\n", vt);
}
}
puts("");
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment