Skip to content

Instantly share code, notes, and snippets.

@wilburding
Created April 4, 2013 15:26
Show Gist options
  • Save wilburding/5311360 to your computer and use it in GitHub Desktop.
Save wilburding/5311360 to your computer and use it in GitHub Desktop.
#include <assert.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>
using namespace std;
// Number of loop iterations each worker thread runs; also used to size the
// per-worker result buffers and the node pool of the speed test.
enum { ITERATIONS = 10 * 1000 * 1000 };
// Exponential ("binary") back-off helper for retry loops.
//
// When enabled, each call to backoff() sleeps for the current delay and then
// doubles that delay for the next call, never letting it grow past
// max_backoff. Sleeping is opt-in: unless the object is constructed with
// enabled == true, backoff() returns immediately and the retry loop simply
// spins.
class BinaryBackOff
{
public:
    // smallest effective time is about 20us
    //
    // initial_backoff: delay used by the first sleeping call to backoff().
    // max_backoff:     upper bound the doubled delay is clamped to.
    // enabled:         when false (the default) backoff() does nothing.
    BinaryBackOff(std::chrono::microseconds initial_backoff = std::chrono::microseconds{10},
                  std::chrono::microseconds max_backoff = std::chrono::microseconds{640},
                  bool enabled = false)
        : current_backoff_(initial_backoff),
          max_backoff_(max_backoff),
          enabled_(enabled)
    {}

    // Sleeps for the current delay (only when enabled) and then grows the
    // delay for the next call.
    void backoff()
    {
        if(!enabled_)
        {
            return;
        }
        std::this_thread::sleep_for(current_backoff_);
        if(current_backoff_ < max_backoff_)
        {
            // Clamp after doubling so the delay never overshoots the
            // configured maximum (e.g. 400us -> 640us, not 800us).
            const std::chrono::microseconds doubled = current_backoff_ * 2;
            current_backoff_ = doubled < max_backoff_ ? doubled : max_backoff_;
        }
    }

private:
    std::chrono::microseconds current_backoff_;  // delay for the next sleeping call
    std::chrono::microseconds max_backoff_;      // cap applied when doubling
    bool enabled_;                               // false => backoff() is a no-op
};
// A pointer paired with a 64-bit counter ("tag"), kept side by side in one
// 16-byte-aligned object so that both can be replaced together by a single
// `lock cmpxchg16b`. The compare-exchange is written as GCC-style extended
// inline assembly using an x86-64 instruction, so this class is specific to
// that compiler family and architecture.
template<class T>
class TaggedPointer
{
public:
// Starts out holding `ptr` (null by default) with the counter at 0.
explicit TaggedPointer(T* ptr = nullptr)
:ptr_(ptr),
counter_(0)
{}
// Returns the pointer half with an acquire load. The pointer and the counter
// are loaded by separate calls, so one load_ptr() plus one load_counter() is
// not a single atomic snapshot of the pair.
inline T* load_ptr() noexcept
{
return ptr_.load(memory_order_acquire);
}
// Returns the counter half with an acquire load.
inline uint64_t load_counter() noexcept
{
return counter_.load(memory_order_acquire);
}
// Attempts to replace the stored (pointer, counter) pair with
// (desired_ptr, desired_counter) in one 16-byte compare-and-swap, succeeding
// only if the stored pair equals (expected_ptr, expected_counter).
// Returns true when the swap happened (taken from the zero flag via `setz`).
//
// Operand mapping, as fixed by the constraints below:
//   "a" (RAX) = expected_ptr      "d" (RDX) = expected_counter
//   "b" (RBX) = desired_ptr       "c" (RCX) = desired_counter
// NOTE(review): this relies on ptr_ being the low 8 bytes and counter_ the
// high 8 bytes of *this with no padding, i.e. on atomic<T*> and
// atomic<uint64_t> being plain 8-byte objects - confirm for the target ABI.
//
// The expected_* parameters are taken by value, so whatever the instruction
// leaves in RAX/RDX on failure is not visible to the caller; callers reload
// the pair before retrying.
inline bool compare_exchange(T* expected_ptr, uint64_t expected_counter, T* desired_ptr, uint64_t desired_counter) noexcept
{
bool result;
// "cc" and "memory" are listed as clobbers: the instruction changes the flags
// and writes to memory the compiler must not assume unchanged.
asm volatile (
"lock cmpxchg16b %0;"
"setz %3;"
:"+m"(*this), "+a"(expected_ptr), "+d"(expected_counter), "=q"(result)
:"b"(desired_ptr), "c"(desired_counter)
:"cc", "memory"
);
return result;
}
private:
// Declared atomic so that reads go through real loads; the original author's
// open question was whether the compiler might otherwise keep the values in
// registers.
atomic<T*> ptr_;
atomic<uint64_t> counter_;
} __attribute__ (( __aligned__(16) ));
// Element of the intrusive singly linked list that backs Stack. Kept as a
// plain aggregate; callers create nodes with brace initialization.
struct Node
{
int value; // payload handed back by Stack::pop()
Node* next; // node below this one; plain pointer - open question from the author: need be atomic?
};
class Stack
{
public:
inline bool try_push(Node* node)
{
Node* head = head_.load_ptr();
uint64_t counter = head_.load_counter();
//node->next.store(head, memory_order_relaxed);
node->next = head;
return head_.compare_exchange(head, counter, node, counter + 1);
}
void push(Node* node)
{
BinaryBackOff bb;
while(!try_push(node))
{
bb.backoff();
}
}
inline bool try_pop(int& value)
{
Node* head = head_.load_ptr();
uint64_t counter = head_.load_counter();
if(!head)
{
value = -1;
return true;
}
if(head_.compare_exchange(head, counter, head->next, counter + 1))
{
/*delete head;*/
value = head->value;
return true;
}
else
{
return false;
}
}
int pop()
{
int res;
BinaryBackOff bb;
while(!try_pop(res))
{
bb.backoff();
}
return res;
}
private:
TaggedPointer<Node> head_;
};
// Held by each correctness worker while it appends its results to the two
// vectors below.
mutex stat_lock;
// Every value pushed by any correctness worker; sorted and compared against
// total_pops once the workers have been joined.
vector<int> total_pushes;
// Every value successfully popped by any correctness worker.
vector<int> total_pops;
// Correctness-test worker, meant to run on its own thread.
//
// Performs ITERATIONS randomly chosen operations on the shared stack (push of
// the loop index, or pop) and then keeps popping until the stack reports
// empty. Every value pushed and every value popped is recorded locally and
// finally appended to total_pushes / total_pops under stat_lock, so the
// caller can compare the two collections after joining all workers.
//
// id:    label printed together with this worker's elapsed time.
// stack: stack shared with the other workers.
//
// Note: nodes are allocated with `new` and never released; Stack does not
// delete the nodes it pops.
void worker_correctness(int id, std::shared_ptr<Stack> stack)
{
    std::minstd_rand rd;  // default-seeded: every worker draws the same sequence
    std::uniform_int_distribution<> uid(0, 1);
    std::vector<int> pushes;
    pushes.reserve(ITERATIONS * 3 / 2);
    std::vector<int> pops;
    pops.reserve(ITERATIONS * 3 / 2);

    const auto begin_time = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < ITERATIONS; ++i)
    {
        if(uid(rd) == 0)
        {
            pushes.push_back(i);
            stack->push(new Node{i, nullptr});
        }
        else
        {
            const int value = stack->pop();
            if(value >= 0)
                pops.push_back(value);
        }
    }
    // Drain whatever is left; pop() yields a negative value once the stack
    // has been observed empty.
    while(true)
    {
        const int value = stack->pop();
        if(value < 0)
            break;
        pops.push_back(value);
    }
    const auto end_time = std::chrono::high_resolution_clock::now();

    {
        std::lock_guard<std::mutex> holder(stat_lock);
        total_pushes.insert(total_pushes.end(), pushes.begin(), pushes.end());
        total_pops.insert(total_pops.end(), pops.begin(), pops.end());
    }

    // duration::count() returns an implementation-chosen integer type, so it
    // is converted explicitly to the type %lld expects.
    const auto elapsed_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(end_time - begin_time).count();
    std::printf("%d: %lldms\n", id, static_cast<long long>(elapsed_ms));
}
// Runs four correctness workers against one shared stack and checks that the
// values popped are exactly the values pushed (compared as sorted sequences).
// Prints "good" on a match and "bad" otherwise.
void test_correctness()
{
    auto stack = std::make_shared<Stack>();

    std::thread workers[] = {
        std::thread{worker_correctness, 1, stack},
        std::thread{worker_correctness, 2, stack},
        std::thread{worker_correctness, 3, stack},
        std::thread{worker_correctness, 4, stack}
    };
    for(auto& worker : workers)
    {
        worker.join();
    }

    // Order of completion differs between runs, so compare after sorting.
    std::sort(total_pushes.begin(), total_pushes.end());
    std::sort(total_pops.begin(), total_pops.end());

    const char* const verdict = (total_pushes == total_pops) ? "good" : "bad";
    std::cout << verdict << std::endl;
}
// Speed-test worker, meant to run on its own thread.
//
// Alternates between pushing and popping for ITERATIONS steps: on even steps
// it pushes pool[i] (after storing i in it), on odd steps it pops and discards
// the result. Afterwards it pops until the stack reports empty, then prints
// the elapsed time.
//
// id:    label printed together with the elapsed time.
// stack: stack shared with the other workers.
// pool:  caller-owned node storage, indexed up to ITERATIONS - 1; nodes are
//        not allocated or freed here.
void worker_speed(int id, Stack& stack, Node* pool)
{
    const auto begin_time = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < ITERATIONS; ++i)
    {
        if((i & 1) == 0)
        {
            pool[i].value = i;
            stack.push(&pool[i]);
        }
        else
        {
            stack.pop();
        }
    }
    // Drain: pop() yields a negative value once the stack has been observed
    // empty.
    while(stack.pop() >= 0)
    {
    }
    const auto end_time = std::chrono::high_resolution_clock::now();

    // duration::count() returns an implementation-chosen integer type, so it
    // is converted explicitly to the type %lld expects.
    const auto elapsed_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(end_time - begin_time).count();
    std::printf("%d: %lldms\n", id, static_cast<long long>(elapsed_ms));
}
void test_speed()
{
auto stack = make_shared<Stack>();
Node* pool = new Node[ITERATIONS * 4];
thread threads[] = {
thread{worker_speed, 1, ref(*stack.get()), pool},
thread{worker_speed, 2, ref(*stack.get()), pool + ITERATIONS},
thread{worker_speed, 3, ref(*stack.get()), pool + 2 * ITERATIONS},
thread{worker_speed, 4, ref(*stack.get()), pool + 3 * ITERATIONS}
};
for(auto& thread: threads)
{
thread.join();
}
}
// Calls f(args...) `repeat` times and prints the total wall-clock time in
// microseconds and the average time per call in nanoseconds.
//
// f:      callable to time.
// repeat: number of invocations; 0 is allowed and reports an average of 0.
// args:   arguments handed to every invocation. They are passed as lvalues on
//         each call: forwarding them as rvalues would let the first call move
//         from them and leave later calls with moved-from objects.
template<class F, class ...Args>
void timeit(F f, uint32_t repeat, Args&&... args)
{
    const auto begin_time = std::chrono::high_resolution_clock::now();
    for(uint32_t i = 0; i < repeat; ++i)
    {
        f(args...);
    }
    const auto end_time = std::chrono::high_resolution_clock::now();

    // duration::count() returns an implementation-chosen integer type, so the
    // values are converted explicitly to the type %lld expects.
    const auto elapsed = end_time - begin_time;
    const long long total_us = static_cast<long long>(
        std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count());
    const long long total_ns = static_cast<long long>(
        std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count());
    // Guard the division: repeat == 0 would otherwise divide by zero.
    const long long average_ns = (repeat != 0) ? total_ns / repeat : 0;

    printf("total time: %lldus\n", total_us);
    printf("average time: %lldns\n", average_ns);
}
// Entry point: runs the correctness test followed by the speed test.
int main(int argc, char* argv[])
{
    // Command-line arguments are accepted but not consulted.
    (void)argc;
    (void)argv;

    test_correctness();
    test_speed();

    // Disabled timing experiments, kept for reference:
    //
    //   timeit(rand, 10 * 1000 * 1000);
    //
    //   minstd_rand mr;
    //   uniform_int_distribution<> uid(0, 1000000);
    //   timeit([&uid](minstd_rand& r){ uid(r); }, 10 * 1000 * 1000, ref(mr));
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment