Skip to content

Instantly share code, notes, and snippets.

@salamantos
Created April 2, 2018 23:44
Show Gist options
  • Save salamantos/b3737e1b7f11118f5ec0c6eff0e8c26e to your computer and use it in GitHub Desktop.
Save salamantos/b3737e1b7f11118f5ec0c6eff0e8c26e to your computer and use it in GitHub Desktop.
#pragma once
#include <tpcc/stdlike/atomic.hpp>
#include <tpcc/stdlike/condition_variable.hpp>
#include <tpcc/stdlike/mutex.hpp>
#include <tpcc/support/compiler.hpp>
#include <algorithm>
#include <atomic>
#include <forward_list>
#include <functional>
#include <iterator>
#include <shared_mutex>
#include <utility>
#include <vector>
namespace tpcc {
namespace solutions {
////////////////////////////////////////////////////////////////////////////////
// implement writer-priority rwlock
// Reader/writer lock that favours writers.
//
// All state is guarded by a single mutex; every state change is announced on
// a single condition variable. A writer is counted in `write_acquire_` from
// the moment it enters lock() until it leaves unlock(), and readers refuse to
// enter while that counter is non-zero, so newly arriving readers queue up
// behind any waiting writer.
class ReaderWriterLock {
 public:
  // reader section / shared ownership

  // Waits until no writer is pending or active, then registers one reader.
  void lock_shared() {
    std::unique_lock<std::mutex> guard{mutex_};
    while (write_acquire_ != 0) {
      cond_var_.wait(guard);
    }
    ++readers_;
  }

  // Unregisters one reader and wakes every waiter so they re-check state.
  void unlock_shared() {
    std::unique_lock<std::mutex> guard{mutex_};
    --readers_;
    cond_var_.notify_all();
  }

  // writer section / exclusive ownership

  // Announces the writer first (which blocks new readers), then waits for
  // the active readers and the current writer, if any, to leave.
  void lock() {
    std::unique_lock<std::mutex> guard{mutex_};
    ++write_acquire_;
    while (write_ || readers_ != 0) {
      cond_var_.wait(guard);
    }
    write_ = true;
  }

  // Gives up exclusive ownership and wakes every waiter.
  void unlock() {
    std::unique_lock<std::mutex> guard{mutex_};
    write_ = false;
    --write_acquire_;
    cond_var_.notify_all();
  }

 private:
  bool write_{false};        // true while a writer owns the lock
  size_t write_acquire_{0};  // writers that are waiting for or holding the lock
  size_t readers_{0};        // readers currently holding the lock
  std::mutex mutex_;
  tpcc::condition_variable cond_var_;
};
////////////////////////////////////////////////////////////////////////////////
template <typename T, class HashFunction = std::hash<T>>
class StripedHashSet {
private:
using RWLock = ReaderWriterLock; // std::shared_timed_mutex
using ReaderLocker = std::shared_lock<RWLock>;
using WriterLocker = std::unique_lock<RWLock>;
using Bucket = std::forward_list<T>;
using Buckets = std::vector<Bucket>;
public:
explicit StripedHashSet(const size_t concurrency_level = 4,
const size_t growth_factor = 2,
const double max_load_factor = 0.8)
: concurrency_level_(concurrency_level),
growth_factor_(growth_factor),
max_load_factor_(max_load_factor),
locks_(concurrency_level),
buckets_(concurrency_level),
size_(0) {
}
bool Insert(T element) {
auto element_hash = hash_(element);
auto stripe_lock = LockStripe<WriterLocker>(element_hash);
if (not Contains(element)){
auto bucket = GetBucket(element_hash);
bucket.push_front(element);
size_++;
return true;
}
return false;
}
bool Remove(const T& element) {
auto element_hash = hash_(element);
auto stripe_lock = LockStripe<WriterLocker>(element_hash);
if (Contains(element)){
auto bucket = GetBucket(element_hash);
auto erase_elem_iterator = std::find(bucket.begin(), bucket.end(), element);
std::advance(erase_elem_iterator, -1);
bucket.erase_after(erase_elem_iterator);
return true;
}
}
bool Contains(const T& element) const {
size_t element_hash = hash_(element);
auto stripe_lock = LockStripe<ReaderLocker>(element_hash);
auto bucket = GetBucket(element_hash);
return std::find(bucket.begin(), bucket.end(), element) != bucket.end();
}
size_t GetSize() const {
auto stripe_lock = LockStripe<ReaderLocker>(0);
return size_;
}
size_t GetBucketCount() const {
// for testing purposes
// do not optimize, just acquire arbitrary lock and read bucket count
auto stripe_lock = LockStripe<ReaderLocker>(0);
return buckets_.size();
}
private:
size_t GetStripeIndex(const size_t hash_value) const {
return hash_value % concurrency_level_;
}
// use: auto stripe_lock = LockStripe<ReaderLocker>(hash_value);
template <class Locker>
Locker LockStripe(const size_t hash_value) const {
return Locker{locks_[GetStripeIndex(hash_value)]};
//return Locker{locks_[0]};
}
size_t GetBucketIndex(const size_t hash_value) const {
return hash_value % buckets_.size();
}
Bucket& GetBucket(const size_t hash_value) {
return buckets_[hash_value % buckets_.size()];
}
const Bucket& GetBucket(const size_t hash_value) const {
return buckets_[hash_value % buckets_.size()];
}
bool MaxLoadFactorExceeded() const {
return false; // to be implemented
}
void TryExpandTable(const size_t expected_bucket_count) {
UNUSED(expected_bucket_count);
// to be implemented
}
private:
size_t concurrency_level_;
size_t growth_factor_;
double max_load_factor_;
std::vector<RWLock> locks_;
Buckets buckets_;
HashFunction hash_;
size_t size_;
// to be continued
};
} // namespace solutions
} // namespace tpcc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment