Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
//
// Simple unique lock guard to detect concurrent access to unsynchronized global data.
//
// This can be a useful tool to debug data races and identify places where a mutex should
// be introduced. Detection is best-effort: it only fires when two accesses actually overlap
// at runtime. Once the race is fixed, you'll likely want to strip out this code.
//
// --------------------------------------------------------
#include <atomic>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <thread>
// --------------------------------------------------------
// Holder count for one piece of unsynchronized data.
// UniqueLockChecker uses it to detect when two scopes are inside the guarded
// section at the same time.
class UniqueLockFlag final
{
public:
    UniqueLockFlag()
        : m_lockCount{ 0 }
    { }

    // A flag identifies exactly one guarded resource; copying it makes no sense.
    UniqueLockFlag(const UniqueLockFlag &) = delete;
    UniqueLockFlag & operator=(const UniqueLockFlag &) = delete;

    // Atomically claims the flag if nobody holds it.
    // Returns true when the caller became the sole holder (count went 0 -> 1).
    // Returns false, leaving the count unchanged, when it is already held.
    bool tryLock() noexcept
    {
        int expected = 0;
        return m_lockCount.compare_exchange_strong(expected, 1);
    }

    // Unconditional increment / decrement of the holder count.
    // Having lock()/unlock() keeps the type usable with std::lock_guard /
    // std::unique_lock, but lock() itself performs no race detection.
    void lock() noexcept { m_lockCount++; }
    void unlock() noexcept { m_lockCount--; }

    // True while at least one holder is inside the guarded section.
    // Compare against 0 rather than 1: with overlapping holders the count can
    // exceed 1, and that state must still read as "locked".
    bool isLocked() const noexcept { return m_lockCount != 0; }

private:
    std::atomic<int> m_lockCount;
};
// Or use std::unique_lock<Mtx>
// RAII scope guard: claims the flag on construction and releases it on
// destruction. If the flag is already held when the guard is created, it
// prints a diagnostic to stderr and calls std::abort().
// Not reentrant: nesting two guards on the same flag, even on one thread,
// also aborts.
struct UniqueLockChecker final
{
    explicit UniqueLockChecker(UniqueLockFlag & l)
        : m_lock{ l }
    {
        // Check and acquire in one atomic step. Calling isLocked() and then
        // lock() separately would let two threads both observe the flag as
        // free and both enter, so the race would go undetected.
        if (!m_lock.tryLock())
        {
            std::fprintf(stderr, "Data race detected between threads! Terminating the program now...\n");
            std::abort();
        }
    }
    ~UniqueLockChecker()
    {
        m_lock.unlock();
    }

    // A copy would unlock the flag a second time on destruction.
    UniqueLockChecker(const UniqueLockChecker &) = delete;
    UniqueLockChecker & operator=(const UniqueLockChecker &) = delete;

    UniqueLockFlag & m_lock;
};
// --------------------------------------------------------
// Shared counter that is deliberately left without a mutex, plus the flag
// that UniqueLockChecker uses to watch accesses to it.
static int g_guardedVariable = 0;
static UniqueLockFlag g_guardedVariableLock;

// Increments the shared counter and prints its new value.
// The scope guard aborts the program when it finds g_guardedVariableLock
// already held, which is how overlapping calls are reported.
void incrementGlobalVar()
{
    UniqueLockChecker guard{ g_guardedVariableLock };
    ++g_guardedVariable;

    // Printing is slow I/O. Doing it while the guard is alive widens the
    // window in which an overlapping call from another thread can be caught.
    std::printf("Global Var = %i\n", g_guardedVariable);
}
// --------------------------------------------------------
// Demonstrates the checker in three scenarios: sequential calls on one
// thread, threads run strictly one after another, and two threads running
// concurrently (which may abort).
int main()
{
    // 1) Safe: all increments happen on the main thread, one after another.
    {
        for (int call = 0; call < 4; ++call)
        {
            incrementGlobalVar();
        }
        assert(g_guardedVariable == 4);
    }

    // 2) Safe: each worker is joined before the next one is started.
    {
        std::thread first{ &incrementGlobalVar };
        first.join();
        std::thread second{ &incrementGlobalVar };
        second.join();
        assert(g_guardedVariable == 6);
    }

    // 3) Unsafe: the main thread and a worker increment at the same time.
    //    Whether an overlap happens (and the program aborts) depends on timing.
    {
        constexpr int NumRuns = 20;
        auto hammer = []() {
            for (int run = 0; run < NumRuns; ++run)
            {
                incrementGlobalVar();
            }
        };
        std::thread worker{ hammer };
        hammer();
        worker.join();
        assert(g_guardedVariable == NumRuns * 2 + 6);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment