Skip to content

Instantly share code, notes, and snippets.

@skyzh
Last active April 27, 2019 14:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save skyzh/d7d841e09e06cf908866bc06b38d8398 to your computer and use it in GitHub Desktop.
Save skyzh/d7d841e09e06cf908866bc06b38d8398 to your computer and use it in GitHub Desktop.
#include <climits>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
using namespace std;
// Unbalanced binary search tree that allows duplicate values.
// Ordering invariant kept by every mutating member:
//   every value in a node's left subtree <= node value < every value in its right subtree
// (duplicates are inserted to the left). find() and delete_interval() rely on it.
// The tree owns its nodes: they are freed on destruction, and copying is disabled so two
// trees can never share (and double-free) the same nodes.
template<typename T>
struct BST {
    struct Node {
        T x;
        Node *l, *r;
        Node(Node *l = nullptr, Node *r = nullptr) : x(), l(l), r(r) {}
        Node(const T &x, Node *l = nullptr, Node *r = nullptr) : x(x), l(l), r(r) {}
        // Dumps this subtree to stdout: the value, then an "L" marker and the left subtree,
        // then an "R" marker and the right subtree, each line indented by `depth` spaces.
        // Intentionally disabled by the early return; delete it to get the dump.
        // NOTE(review): stdout appears to carry the program's answers, so re-enabling this
        // would interleave debug text with them.
        void debug(int depth = 0) {
            return;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << x << std::endl;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "L" << std::endl;
            if (l) l->debug(depth + 1);
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "R" << std::endl;
            if (r) r->debug(depth + 1);
        }
    } *root;

    BST() : root(nullptr) {}
    ~BST() { clear(root); }
    BST(const BST &) = delete;
    BST &operator=(const BST &) = delete;

    // Frees every node of the subtree rooted at ptr (post-order).
    // Pointers to ptr held elsewhere (e.g. root) are NOT reset by this call.
    void clear(Node *ptr) {
        if (!ptr) return;
        clear(ptr->l);
        clear(ptr->r);
        delete ptr;
    }

    // Returns true if x occurs in the subtree rooted at ptr.
    // Uses the ordering invariant, so it follows a single root-to-leaf path.
    bool find(Node *ptr, const T &x) {
        while (ptr) {
            if (x == ptr->x) return true;
            ptr = (x < ptr->x) ? ptr->l : ptr->r;
        }
        return false;
    }
    bool find(const T &x) {
        return find(root, x);
    }

    // Inserts x into the subtree rooted at ptr and returns the (possibly new) subtree root.
    // Values equal to a node go to its left, matching the ordering invariant.
    Node *insert(Node *ptr, const T &x) {
        if (!ptr) return new Node(x);
        if (x <= ptr->x) ptr->l = insert(ptr->l, x);
        else ptr->r = insert(ptr->r, x);
        return ptr;
    }
    void insert(const T &x) {
        root = insert(root, x);
    }

    // In-order search for the i-th smallest value (1-based) in ptr's subtree.
    // `i` is decremented once per node passed, so callers share one counter across the walk.
    // Returns nullptr when i < 1 or i exceeds the subtree size.
    Node *find_ith(Node *ptr, int &i) {
        if (!ptr) return nullptr;
        if (Node *found = find_ith(ptr->l, i)) return found;
        if (i == 1) return ptr;
        --i;
        return find_ith(ptr->r, i);
    }
    Node *find_ith(int i) {
        return find_ith(root, i);
    }

    // Removes every value <= x (inclusive bound).
    void delete_less_than(const T &x) {
        delete_interval(std::numeric_limits<T>::lowest(), x);
    }
    // Removes every value >= x (inclusive bound).
    void delete_greater_than(const T &x) {
        delete_interval(x, std::numeric_limits<T>::max());
    }

    // Unlinks ptr from its subtree and returns the subtree that should replace it.
    // ptr itself is NOT freed; the caller deletes it. If ptr has a left child, ptr is replaced
    // by its in-order predecessor (the rightmost node of the left subtree), which keeps the
    // ordering invariant: everything left of ptr is <= the predecessor < everything right of ptr.
    Node *delete_node_at(Node *ptr) {
        if (!ptr) return nullptr;
        if (!ptr->l) return ptr->r;
        Node *parent = nullptr, *pred = ptr->l;
        while (pred->r) {
            parent = pred;
            pred = pred->r;
        }
        if (parent) {
            // pred has no right child; its left subtree takes its place under parent.
            parent->r = pred->l;
            pred->l = ptr->l;
        }
        pred->r = ptr->r;
        return pred;
    }

    // Removes one occurrence of x from ptr's subtree (the first one found on the search path)
    // and returns the new subtree root. Does nothing if x is absent.
    Node *delete_node(Node *ptr, const T &x) {
        if (!ptr) return nullptr;
        if (x < ptr->x) {
            ptr->l = delete_node(ptr->l, x);
        } else if (x == ptr->x) {
            Node *result = delete_node_at(ptr);
            delete ptr;
            return result;
        } else {
            ptr->r = delete_node(ptr->r, x);
        }
        return ptr;
    }
    void delete_node(const T &x) {
        root = delete_node(root, x);
    }

    // Removes every value v with t1 <= v <= t2 from ptr's subtree and returns the new subtree root.
    // [x1, x2] must bound every value in the subtree (inclusive). When it lies entirely inside
    // [t1, t2], the whole subtree is freed at once. Subtrees that cannot intersect [t1, t2] are not
    // visited. The right child is given the looser lower bound ptr->x (its values are > ptr->x),
    // which avoids computing ptr->x + 1 (overflow at the type's maximum) and works for any ordered T.
    Node *delete_interval(Node *ptr, const T &x1, const T &x2, const T &t1, const T &t2) {
        if (!ptr) return nullptr;
        if (t1 <= x1 && x2 <= t2) {
            clear(ptr);
            return nullptr;
        }
        // Left values are <= ptr->x: nothing there can be >= t1 if ptr->x < t1.
        if (t1 <= ptr->x) ptr->l = delete_interval(ptr->l, x1, ptr->x, t1, t2);
        // Right values are > ptr->x: nothing there can be <= t2 if ptr->x >= t2.
        if (ptr->x < t2) ptr->r = delete_interval(ptr->r, ptr->x, x2, t1, t2);
        if (t1 <= ptr->x && ptr->x <= t2) {
            Node *rest = delete_node_at(ptr);
            delete ptr;
            return rest;
        }
        return ptr;
    }
    // Removes every value in the inclusive range [t1, t2]; an empty range (t1 > t2) removes nothing.
    // Returns the new root (previously this function returned nothing, which was undefined behavior).
    Node *delete_interval(const T &t1, const T &t2) {
        root = delete_interval(root, std::numeric_limits<T>::lowest(),
                               std::numeric_limits<T>::max(), t1, t2);
        return root;
    }
};
int main() {
    /*
    Input: an integer N followed by N commands, each one of:
    "insert x"              insert x
    "delete x"              delete x (if x occurs several times, delete any one copy)
    "delete_less_than x"    delete every element smaller than x
    "delete_greater_than x" delete every element greater than x
    "delete_interval a b"   delete every element greater than a and smaller than b
    "find x"                print Y if x is in the tree, otherwise N
    "find_ith i"            print the i-th smallest element, or N if there is none
    */
    int N = 0;
    if (!(std::cin >> N)) return 0;
    BST<int> tree;
    // std::string rather than a fixed char buffer: an over-long token can no longer overflow it.
    std::string cmd;
    int op1 = 0, op2 = 0;
    for (int i = 0; i < N; i++) {
        // Stop on truncated input instead of replaying the previous command with stale operands.
        if (!(std::cin >> cmd)) break;
        if (cmd == "insert") {
            if (!(std::cin >> op1)) break;
            tree.insert(op1);
        } else if (cmd == "delete") {
            if (!(std::cin >> op1)) break;
            tree.delete_node(op1);
        } else if (cmd == "delete_less_than") {
            if (!(std::cin >> op1)) break;
            // BST::delete_less_than(x) removes values <= x, hence op1 - 1.
            // Nothing is smaller than INT_MIN, and op1 - 1 would overflow there.
            if (op1 > INT_MIN) tree.delete_less_than(op1 - 1);
        } else if (cmd == "delete_greater_than") {
            if (!(std::cin >> op1)) break;
            // BST::delete_greater_than(x) removes values >= x, hence op1 + 1 (skipped at INT_MAX).
            if (op1 < INT_MAX) tree.delete_greater_than(op1 + 1);
        } else if (cmd == "delete_interval") {
            if (!(std::cin >> op1 >> op2)) break;
            // BST::delete_interval is inclusive, so the open interval (a, b) becomes [a + 1, b - 1].
            // If a == INT_MAX or b == INT_MIN the open interval is empty (and +-1 would overflow).
            if (op1 < INT_MAX && op2 > INT_MIN) tree.delete_interval(op1 + 1, op2 - 1);
        } else if (cmd == "find") {
            if (!(std::cin >> op1)) break;
            std::cout << (tree.find(op1) ? "Y" : "N") << '\n';
        } else if (cmd == "find_ith") {
            if (!(std::cin >> op1)) break;
            BST<int>::Node *ith = tree.find_ith(op1);
            if (ith) std::cout << ith->x << '\n';
            else std::cout << "N" << '\n';
        }
        // Unknown commands are ignored, as before.
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment