Skip to content

Instantly share code, notes, and snippets.

@kerrytazi
Last active June 12, 2023 19:54
Show Gist options
  • Save kerrytazi/9e0cc11ee437833d50a136d4ffdd18a8 to your computer and use it in GitHub Desktop.
Save kerrytazi/9e0cc11ee437833d50a136d4ffdd18a8 to your computer and use it in GitHub Desktop.
A safe mutex wrapper that won't let you use the protected data without holding its lock.
#pragma once
#include <mutex>
#include <new>
#include <thread>
#include <utility>
template <typename TObj, typename TMutex = std::mutex>
class MutexedGuard;

// Stores an object of type TObj next to a mutex of type TMutex.
// Neither member is publicly accessible, so the object can only be reached
// through the guard returned by lock(): `lock().get()`, `*lock()` or
// `lock()->member`.
template <typename TObj, typename TMutex = std::mutex>
class Mutexed
{
public:
    using object_type = TObj;
    using mutex_type = TMutex;
    using guard_type = MutexedGuard<TObj, TMutex>;
    using mutexed_type = Mutexed<TObj, TMutex>;

    // The guard reads _obj and _mtx. `friend class guard_type;` is
    // ill-formed: `class X` may not name a type alias, and GCC/Clang reject
    // it. Naming the class template specialization directly is portable.
    friend class MutexedGuard<TObj, TMutex>;

private:
    TObj _obj;
    TMutex _mtx;

public:
    // Moves `obj` in. With no argument, the object is value-initialized
    // (TObj{}).
    Mutexed(TObj &&obj = TObj{}) : _obj(std::move(obj)) {}

    // Calls TMutex::lock() (blocking, for std::mutex) and returns a guard
    // that owns the lock and releases it when destroyed. Discarding the
    // result therefore locks and immediately unlocks.
    guard_type lock() { return guard_type(*this); }
};

// Scoped, move-only lock over a Mutexed<TObj, TMutex>. Only
// Mutexed::lock() can create one. While a guard holds the lock, get() / * /
// -> give access to the protected object, and the destructor unlocks.
// After a guard has been moved from, it no longer holds the lock and must
// not be used to access the object; get() does not check this.
template <typename TObj, typename TMutex>
class MutexedGuard
{
public:
    using object_type = TObj;
    using mutex_type = TMutex;
    using guard_type = MutexedGuard<TObj, TMutex>;
    using mutexed_type = Mutexed<TObj, TMutex>;

    // Mutexed::lock() calls the private constructor. The class template is
    // named directly because `friend class mutexed_type;` (alias after
    // `class`) is ill-formed.
    friend class Mutexed<TObj, TMutex>;

private:
    // Pointer rather than reference, so move assignment can rebind it
    // without the destroy + placement-new trick.
    mutexed_type *_mutexed;
    // True exactly while this guard is responsible for unlocking
    // _mutexed->_mtx.
    bool _locked;

    // Acquires the mutex, unless `locked` says the caller already holds it.
    MutexedGuard(mutexed_type &mutexed, bool locked = false) : _mutexed(&mutexed), _locked(locked)
    {
        if (!_locked)
        {
            _mutexed->_mtx.lock();
            _locked = true;
        }
    }

    // Unlocks the mutex if this guard holds it, then marks the guard as
    // not holding it.
    void unlock_if_held() noexcept
    {
        if (_locked)
        {
            _locked = false;
            _mutexed->_mtx.unlock();
        }
    }

public:
    MutexedGuard(const MutexedGuard &) = delete;
    MutexedGuard &operator = (const MutexedGuard &) = delete;

    // Takes over `other`'s target and its lock, if it holds one. The
    // actual lock state is copied, not assumed: moving from an already
    // moved-from guard must not produce a guard that later unlocks a mutex
    // it never held.
    MutexedGuard(MutexedGuard &&other) noexcept
        : _mutexed(other._mutexed), _locked(other._locked)
    {
        other._locked = false;
    }

    // Releases whatever lock this guard holds, then takes over `other`'s
    // target and lock. Self-move-assignment leaves the guard unchanged.
    MutexedGuard &operator = (MutexedGuard &&other) noexcept
    {
        if (this != &other)
        {
            unlock_if_held();
            _mutexed = other._mutexed;
            _locked = other._locked;
            other._locked = false;
        }
        return *this;
    }

    ~MutexedGuard()
    {
        unlock_if_held();
    }

    // Access to the protected object. Only valid while this guard holds
    // the lock, i.e. it has not been moved from.
    object_type &get() const { return _mutexed->_obj; }
    object_type &operator *() const { return get(); }
    object_type *operator ->() const { return &get(); }
};
// Example of safe use: two threads increment and decrement the same counter
// 100000 times each. Every access goes through a lock guard, so the printed
// value is always 0.
#if 0
#include <iostream>
int main()
{
    Mutexed<int> safe_counter;

    std::thread incrementer([&]() {
        for (size_t n = 0; n != 100000; ++n)
            ++safe_counter.lock().get();
    });
    std::thread decrementer([&]() {
        for (size_t n = 0; n != 100000; ++n)
            --safe_counter.lock().get();
    });

    incrementer.join();
    decrementer.join();

    {
        // Hold the lock only for as long as the read-and-print takes.
        const auto guard = safe_counter.lock();
        std::cout << "Counter: " << *guard << "\n";
    }
}
#endif
// Example of unsafe use: the same workload on a plain int with no lock.
// The two threads race on `counter`. This is a data race and therefore
// undefined behavior; in practice the printed value is often not 0.
#if 0
#include <iostream>
int main()
{
    int counter = 0;

    std::thread incrementer([&]() {
        for (size_t n = 0; n != 100000; ++n)
            ++counter;
    });
    std::thread decrementer([&]() {
        for (size_t n = 0; n != 100000; ++n)
            --counter;
    });

    incrementer.join();
    decrementer.join();

    std::cout << "Counter: " << counter << "\n";
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment