Skip to content

Instantly share code, notes, and snippets.

@plasmatic1
Last active January 9, 2019 21:59
Show Gist options
  • Save plasmatic1/0f1c1593504bddb0bfb2def797398515 to your computer and use it in GitHub Desktop.
Save plasmatic1/0f1c1593504bddb0bfb2def797398515 to your computer and use it in GitHub Desktop.
A Treap Implementation
//============================================================================
// Name : BinarySearchTreeTest.cpp
// Author : Daxi the Taxi
// Version :
// Copyright : ALL YOUR CODE IS BELONG TO US
// Description : Treap (randomized binary search tree) with duplicate counts, rank queries, and a small self-test driver
//============================================================================
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Subtree size of a possibly-null node pointer: a null pointer counts as 0,
// otherwise the node's own `size` field is used.
#define SIZE(node) ((node) ? (node)->size : 0)
// Inclusive bounds handed to the priority distribution by init().
const int MIN_PRI = 0;
const int MAX_PRI = INT_MAX;
// Shared random engine and distribution. Both start out default-constructed
// and are reconfigured by init().
mt19937 twister;
uniform_int_distribution<int> dis;
inline void init(){
random_device seeder;
twister = mt19937(seeder());
dis = uniform_int_distribution<>(MIN_PRI, MAX_PRI);
}
// One treap node.
//   lhs, rhs : child pointers (nullptr when absent)
//   pri      : priority drawn from the shared distribution `dis` at construction
//   size     : starts at -1, which marks a node that does not hold a value yet
//   count    : starts at 1
//   val      : stored value, value-initialised
template <class T>
struct node{
    node *lhs, *rhs;
    int pri, size, count;
    T val;
    // val() value-initialises T (0 for arithmetic types), so T does not have to be
    // constructible from the integer literal 0.
    node() : lhs(nullptr), rhs(nullptr), size(-1), count(1), val(){
        pri = dis(twister);
    }
};
// Treap: a binary search tree ordered by `val` in which a child's `pri` never
// exceeds its parent's (rotations restore this after every insert). Equal
// values share one node and are tracked by `count`; `size` is the number of
// stored values, duplicates included, in a node's subtree.
//
// An empty treap is a single placeholder root with size == -1, so `root` is
// never nullptr.
template <class T>
struct treap{
    node<T> *root;

    treap(){
        root = new node<T>;
    }

    // The treap owns every node reachable from root.
    ~treap(){
        destroy(root);
    }

    // A member-wise copy would share nodes with the source and free them twice.
    treap(const treap&) = delete;
    treap& operator=(const treap&) = delete;

    // Frees the subtree rooted at curr (nullptr is allowed).
    static void destroy(node<T> *curr){
        if(curr == nullptr) return;
        destroy(curr->lhs);
        destroy(curr->rhs);
        delete curr;
    }

    // Rotates croot's left child above croot and returns that child (the new
    // subtree root). Subtree sizes of both nodes are recomputed. When cpar is
    // not nullptr, whichever of its links pointed at croot is redirected.
    // Assumes croot->lhs is not nullptr.
    node<T>* rotateright(node<T> *croot, node<T> *cpar){
        // a, b and c are the sizes of the three subtrees that hang off the two
        // rotated nodes, left to right.
        node<T> *pivot = croot->lhs, *bnode = pivot->rhs;
        int a = SIZE(pivot->lhs), b = SIZE(bnode), c = SIZE(croot->rhs);
        croot->size = b + c + croot->count;
        pivot->size = a + b + c + croot->count + pivot->count;
        croot->lhs = bnode;
        pivot->rhs = croot;
        if(cpar != nullptr){
            if(croot == cpar->lhs) cpar->lhs = pivot;
            else cpar->rhs = pivot;
        }
        return pivot;
    }

    // Mirror image of rotateright: lifts croot's right child above croot.
    // Assumes croot->rhs is not nullptr.
    node<T>* rotateleft(node<T> *croot, node<T> *cpar){
        node<T> *pivot = croot->rhs, *bnode = pivot->lhs;
        int a = SIZE(croot->lhs), b = SIZE(bnode), c = SIZE(pivot->rhs);
        croot->size = a + b + croot->count;
        pivot->size = a + b + c + croot->count + pivot->count;
        croot->rhs = bnode;
        pivot->lhs = croot;
        if(cpar != nullptr){
            if(croot == cpar->lhs) cpar->lhs = pivot;
            else cpar->rhs = pivot;
        }
        return pivot;
    }

    // Adds one copy of v.
    inline void insert(T v){
        rinsert(v, root, nullptr);
    }

    // Recursive insert below curr. curr is a reference to the link that owns
    // the node, so a rotation can replace the subtree root in place; cpar is
    // curr's parent (nullptr at the root).
    void rinsert(T v, node<T>* &curr, node<T> *cpar){
        if(curr->size == -1){
            // Placeholder (freshly allocated child, or the root of an empty
            // treap): claim it for v.
            curr->val = v;
            curr->size = 1;
            // rremove leaves an emptied root with count == 0; without this
            // reset the re-used root would hold a value that byrank skips.
            curr->count = 1;
            return;
        }
        curr->size++;
        if(v < curr->val){
            if(curr->lhs == nullptr) curr->lhs = new node<T>;
            rinsert(v, curr->lhs, curr);
            // Lift the child if it now outranks its parent.
            if(curr->lhs->pri > curr->pri) curr = rotateright(curr, cpar);
        }
        else if(v > curr->val){
            if(curr->rhs == nullptr) curr->rhs = new node<T>;
            rinsert(v, curr->rhs, curr);
            if(curr->rhs->pri > curr->pri) curr = rotateleft(curr, cpar);
        }
        else{ // v == curr->val
            curr->count++;
        }
    }

    // Returns true when at least one copy of v is stored.
    bool contains(T v) const{
        if(root->size == -1) return false;
        const node<T> *curr = root;
        while(curr != nullptr){
            if(v < curr->val) curr = curr->lhs;
            else if(v > curr->val) curr = curr->rhs;
            else return true;
        }
        return false;
    }

    // Removes one copy of v. Does nothing when v is not stored.
    inline void remove(T v){
        // rremove shrinks subtree sizes on its way down, so it must only be
        // started when v is known to be present.
        if(!contains(v)) return;
        rremove(v, root, nullptr);
    }

    // Recursive removal of one copy of v below curr. Sizes on the search path
    // are decremented on the way down, so the caller must guarantee that v is
    // present (remove() does). curr is a reference to the owning link; cpar is
    // curr's parent (nullptr at the root).
    void rremove(T v, node<T>* &curr, node<T> *cpar){
        curr->size--;
        if(v < curr->val){
            if(curr->lhs == nullptr) return;
            rremove(v, curr->lhs, curr);
        }
        else if(v > curr->val){
            if(curr->rhs == nullptr) return;
            rremove(v, curr->rhs, curr);
        }
        else{ // v == curr->val
            if(curr->count > 1){
                curr->count--;
            }
            else{
                if(cpar == nullptr && (curr->lhs == nullptr || curr->rhs == nullptr)){
                    if(curr->lhs == nullptr && curr->rhs == nullptr){ // It's the only node
                        // Keep the node and turn it back into the placeholder.
                        curr->size = -1;
                        curr->count = 0;
                    }
                    else { // 1 Child
                        // Copy the pointer first: curr is expected to alias root.
                        node<T> *victim = curr;
                        root = victim->lhs != nullptr ? victim->lhs : victim->rhs;
                        delete victim;
                    }
                }
                else{
                    _rremove(curr, cpar);
                }
            }
        }
    }

    // Unlinks and frees the node curr, whose own size has already been
    // decremented by the caller. A node with two children is rotated downwards
    // (the higher-priority child moves up) until it has at most one child.
    // cpar must not be nullptr unless curr has two children.
    void _rremove(node<T>* &curr, node<T> *cpar){
        if(curr->lhs != nullptr && curr->rhs != nullptr){
            if(curr->lhs->pri > curr->rhs->pri){ // Make left on top (Curr becomes rhs)
                curr = rotateright(curr, cpar);
                // The rotation recomputed sizes with the doomed node still
                // counted; take it out of the new subtree root again.
                curr->size--;
                _rremove(curr->rhs, curr);
            }
            else{ // Make right on top (Curr becomes lhs)
                curr = rotateleft(curr, cpar);
                curr->size--;
                _rremove(curr->lhs, curr);
            }
        }
        else{ // Leaf or one child: splice the (possibly null) child into the parent
            // curr usually aliases cpar->lhs or cpar->rhs, so take a copy before
            // the link is overwritten and free the node only afterwards.
            node<T> *victim = curr;
            node<T> *child = victim->lhs != nullptr ? victim->lhs : victim->rhs;
            if(victim == cpar->lhs) cpar->lhs = child;
            else cpar->rhs = child;
            delete victim;
        }
    }

    // Returns the 1-based position of the first copy of v in sorted order, or
    // -1 when v is not stored.
    inline int search(T v){
        if(root->size == -1) return -1;
        return rsearch(v, root, 0);
    }

    // idx is the number of stored values known to sort before curr's subtree.
    int rsearch(T v, node<T> *curr, int idx){
        if(v < curr->val){
            if(curr->lhs == nullptr) return -1;
            return rsearch(v, curr->lhs, idx);
        }
        else if(v > curr->val){
            if(curr->rhs == nullptr) return -1;
            return rsearch(v, curr->rhs, idx + SIZE(curr->lhs) + curr->count);
        }
        else{
            return idx + SIZE(curr->lhs) + 1;
        }
    }

    // Returns the value at 1-based position idx in sorted order, or -1
    // (converted to T) when idx is outside [1, number of stored values].
    inline T byrank(int idx){
        // Positions below 1 would otherwise descend into a null left child.
        if(idx < 1 || idx > root->size){
            return -1;
        }
        return rbyrank(idx, root);
    }

    // Requires 1 <= idx <= curr->size.
    T rbyrank(int idx, node<T> *curr){
        if(idx > SIZE(curr->lhs) + curr->count){
            return rbyrank(idx - SIZE(curr->lhs) - curr->count, curr->rhs);
        }
        else if(idx > SIZE(curr->lhs)){
            return curr->val;
        }
        return rbyrank(idx, curr->lhs);
    }

    // Number of rtreeorder calls made by the current treeorder() dump.
    int its = 0;

    // Debug dump of the tree shape in pre-order, one node per line, prefixed by
    // one '-' per level of depth.
    inline void treeorder(){
        its = 0;
        rtreeorder(root, 0);
    }

    // Prints curr at depth cn, then its children; a missing child is shown as
    // "nullptr". Output stops once 16 nodes have been printed. The "%d" format
    // means the value is only printed correctly when T is int.
    void rtreeorder(node<T> *curr, int cn){
        if(its++>15)return;
        for(int i = 0; i < cn; i++) printf("-");
        printf(" %d, pri: %d\n", curr->val, curr->pri);
        if(curr->lhs != nullptr) rtreeorder(curr->lhs, cn + 1);
        else{
            for(int i = 0; i <= cn; i++) printf("-");
            printf(" nullptr\n");
        }
        if(curr->rhs != nullptr) rtreeorder(curr->rhs, cn + 1);
        else{
            for(int i = 0; i <= cn; i++) printf("-");
            printf(" nullptr\n");
        }
    }

    // Prints every stored value in ascending order (duplicates repeated),
    // separated by spaces and followed by a newline.
    inline void inorder(){
        // The placeholder root of an empty treap holds no value to print.
        if(root->size != -1) rinorder(root);
        printf("\n");
    }

    // In-order walk below curr; like rtreeorder, the "%d" format assumes T is int.
    void rinorder(node<T> *curr){
        if(curr->lhs != nullptr) rinorder(curr->lhs);
        for(int i = 0; i < curr->count; i++) printf("%d ", curr->val);
        if(curr->rhs != nullptr) rinorder(curr->rhs);
    }
};
// Treap instance exercised by the query helpers and main().
treap<int> tree;
// n, m and buf are not read by any active code in this file (only a
// commented-out `cin >> n` mentions n).
int n, m, buf;
// Running query number printed by the helpers.
int cnt = 0;
// Inserts v into the global treap and logs the query with its running number.
void ainsert(int v){
    tree.insert(v);
    // `printff` is not declared anywhere in this file; printf is the intended call.
    printf("Query #%d: Insert %d\n", cnt++, v);
    // tree.inorder();
}
// Removes one copy of v from the global treap and logs the query with its
// running number.
void adelete(int v){
    tree.remove(v);
    // `printff` is not declared anywhere in this file; printf is the intended call.
    printf("Query #%d: Delete %d\n", cnt++, v);
    // tree.inorder();
}
// Runs tree.search(v), logs the answer, and reports whether it equals
// `expected` (printing the expected value on a mismatch).
void asearch(int v, int expected){
    int ans = tree.search(v);
    // `printff` is not declared anywhere in this file; printf is the intended call.
    printf("Query #%d: Search %d, Ans: %d -- ", cnt++, v, ans);
    if(ans == expected){
        printf("CORRECT\n");
    }
    else{
        printf("INCORRECT -- Correct Answer: %d\n", expected);
    }
    // tree.inorder();
}
// Runs tree.byrank(v), logs the answer, and reports whether it equals
// `expected` (printing the expected value on a mismatch).
void abyrank(int v, int expected){
    int ans = tree.byrank(v);
    // `printff` is not declared anywhere in this file; printf is the intended call.
    printf("Query #%d: Byrank %d, Ans: %d -- ", cnt++, v, ans);
    if(ans == expected){
        printf("CORRECT\n");
    }
    else{
        printf("INCORRECT -- Correct Answer: %d\n", expected);
    }
}
// Self-test driver: seeds the priority generator, fills the global treap with
// ten values, then replays a fixed script of insert / delete / search / byrank
// queries, each of which logs its own result.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    // `printff` is not declared anywhere in this file; printf is the intended call.
    printf(" --- Initialized! --- \n");
    ainsert(49);
    ainsert(37);
    ainsert(32);
    ainsert(69);
    ainsert(73);
    ainsert(9);
    ainsert(82);
    ainsert(11);
    ainsert(86);
    ainsert(27);
    printf("\n --- Inserts Complete, Beginning Normal Operations --- \n");
    asearch(49, 6);
    ainsert(5);
    adelete(82);
    abyrank(4, 27);
    ainsert(71);
    asearch(77, -1);
    adelete(37);
    abyrank(4, 27);
    abyrank(9, 73);
    asearch(47, -1);
    adelete(71);
    ainsert(61);
    abyrank(6, 49);
    adelete(9);
    ainsert(100);
    asearch(73, 8);
    abyrank(3, 27);
    adelete(61);
    ainsert(66);
    asearch(27, 3);
    adelete(32);
    ainsert(25);
    abyrank(4, 27);
    asearch(69, 7);
    asearch(25, 3);
    ainsert(85);
    adelete(25);
    abyrank(3, 27);
    adelete(27);
    asearch(5, 1);
    abyrank(9, 100);
    ainsert(84);
    // cin >> n;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment