Created
August 18, 2022 02:03
-
-
Save water111/2e5755ff0aa85a8d76d482ede76be64b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*!
 * Simple spin-waiting thread pool.
 *
 * One shared work function is run on `num_threads` threads at once: `num_threads - 1` worker
 * threads owned by the pool (indices 0 .. num_threads - 2) plus the thread that calls work()
 * (index num_threads - 1).
 *
 * Typical use: start_spinning(), any number of work() calls, stop_spinning().
 * While "spinning", workers busy-wait for work instead of sleeping on the condition variable.
 * If the pool has worker threads, work() must only be called while they are spinning;
 * otherwise it waits forever for workers that are asleep.
 * The control functions (start_spinning / stop_spinning / work) are meant to be called from a
 * single controlling thread.
 */
class ThreadPoolSpinLock {
 private:
  std::vector<std::thread> m_threads;
  int m_num_threads = 0;
  std::function<void(int)> m_work_function;

  // Thread state. These two variables define the desired states for the worker threads.
  // They are only written from the controlling thread. They are atomic because the workers
  // read them without holding the mutex while spinning.

  // should we be spinning? set under the mutex (with a notify) when waking the workers.
  std::atomic_bool m_wakeup_flag{false};
  // should we be shutting down? wakeup is also set to false when this is set.
  std::atomic_bool m_shutdown_flag{false};

  /// work synchronization.
  /// - (assumed) nobody is working, workers are waiting on "work_ready"
  /// - main thread clears "work_done" flags
  /// - main thread sets "work_ready" flag
  /// - work threads work, then set "work_done"
  /// - main thread waits for "work_done"s.
  std::vector<std::atomic_bool> m_work_ready;
  std::vector<std::atomic_bool> m_done_flags;

  std::mutex m_mutex;
  std::condition_variable m_cv;

  /*!
   * Body of worker thread `thread_idx`: sleep until woken, then spin on this worker's
   * "work_ready" flag and run the work function each time it is raised.
   * Returns once shutdown has been requested.
   */
  void worker_loop(int thread_idx) {
    while (true) {  // outer loop: wake/sleep of the worker.
      {
        std::unique_lock<std::mutex> lk(m_mutex);
        m_cv.wait(lk, [this]() -> bool { return m_wakeup_flag || m_shutdown_flag; });
        if (m_shutdown_flag) {
          return;
        }
      }

      // loop over work, without sleeping
      while (true) {
        while (!m_work_ready[thread_idx] && m_wakeup_flag) {
          _mm_pause();
        }
        if (!m_wakeup_flag) {
          break;  // go back to sleeping on the condition variable
        }
        m_work_ready[thread_idx] = false;
        m_work_function(thread_idx);
        m_done_flags[thread_idx] = true;
      }
    }
  }

 public:
  /*!
   * Create the pool and launch `num_threads - 1` worker threads, which start out asleep.
   * @param num_threads   total number of threads doing work, including the caller of work().
   *                      Must be at least 1.
   * @param work_function called as work_function(i) once per work() call for every
   *                      i in [0, num_threads); each index runs on a different thread.
   */
  ThreadPoolSpinLock(int num_threads, std::function<void(int)> work_function)
      : m_num_threads(num_threads),
        m_work_function(work_function),
        // listed in declaration order, which is the order members are actually initialized in.
        m_work_ready(num_threads - 1),
        m_done_flags(num_threads - 1) {
    // ASSERT(num_thread > 0);
    m_threads.resize(num_threads - 1);
    for (int thread_idx = 0; thread_idx < num_threads - 1; thread_idx++) {
      m_threads[thread_idx] = std::thread([this, thread_idx]() { worker_loop(thread_idx); });
    }
  }

  // Worker threads hold a pointer to this object, so it must not be copied.
  ThreadPoolSpinLock(const ThreadPoolSpinLock&) = delete;
  ThreadPoolSpinLock& operator=(const ThreadPoolSpinLock&) = delete;

  /*!
   * Wake the workers so they busy-wait for work.
   */
  void start_spinning() {
    std::lock_guard<std::mutex> lk(m_mutex);
    m_wakeup_flag = true;
    m_cv.notify_all();
  }

  /*!
   * Ask the workers to stop busy-waiting and go back to sleep.
   */
  void stop_spinning() { m_wakeup_flag = false; }

  /*!
   * Run the work function once on every thread (workers and the caller) and return when all
   * of them have finished.
   */
  void work() {
    const int num_workers = m_num_threads - 1;

    // Clear the "done" flags left over from the previous call. Without this, every call after
    // the first would see stale flags and return before the workers finished.
    // Workers are idle here (the previous call waited for all of them), so nothing races.
    for (int i = 0; i < num_workers; i++) {
      m_done_flags[i] = false;
    }

    // start workers
    for (int i = 0; i < num_workers; i++) {
      m_work_ready[i] = true;
    }

    // our work
    m_work_function(num_workers);

    // wait for done. Flags only go from false to true during this phase, so waiting on each
    // worker in turn is equivalent to repeatedly rescanning all of them.
    for (int i = 0; i < num_workers; i++) {
      while (!m_done_flags[i]) {
        _mm_pause();
      }
    }
  }

  /*!
   * Request shutdown, wake any sleeping workers so they can see it, and join all workers.
   */
  ~ThreadPoolSpinLock() {
    {
      std::lock_guard<std::mutex> lk(m_mutex);
      m_wakeup_flag = false;
      m_shutdown_flag = true;
      m_cv.notify_all();
    }
    for (auto& t : m_threads) {
      t.join();
    }
  }
};
/*!
 * Small deterministic CPU workload used to exercise the thread pool.
 * Each work item writes a coarse grid-based approximation of pi into its own result slot.
 */
class FakeWork {
 public:
  /*!
   * @param num number of independent result slots (one per work index).
   */
  explicit FakeWork(int num) { m_results.resize(num); }

  /*!
   * Run the workload and store its result in slot `idx`.
   * `idx` must be in [0, num); it is not range-checked.
   */
  void do_work(int idx) { m_results[idx] = compute_estimate(); }

  /*!
   * Print every result slot and check it.
   * @return true if every slot holds exactly the value do_work() writes, false if any slot
   *         holds something else (for example the sentinel written by reset()).
   */
  bool confirm_correct() {
    // The workload is deterministic, so an exact comparison against a fresh run is valid.
    const float expected = compute_estimate();
    bool all_match = true;
    for (auto f : m_results) {
      printf("%.3f\n", f);
      if (f != expected) {
        all_match = false;
      }
    }
    return all_match;
  }

  /*!
   * Overwrite every slot with a sentinel (999) that the workload never produces.
   */
  void reset() {
    for (auto& x : m_results) {
      x = 999;
    }
  }

 private:
  /*!
   * Count the grid points (x, y) in [0, res) x [0, res) that lie strictly inside the circle of
   * radius `res`, and scale the fraction by 4.
   */
  static float compute_estimate() {
    int res = 50;
    int count = 0;
    for (int x = 0; x < res; x++) {
      for (int y = 0; y < res; y++) {
        if (x * x + y * y < res * res) {
          count++;
        }
      }
    }
    return 4.f * count / (res * res);
  }

  std::vector<float> m_results;
};
// Number of work items per timed run; the tests below also use it as the result-slot count
// and the pool's thread count.
int NRUNS{3};
// Baseline timing: run all NRUNS work items back to back on the calling thread, ten times,
// printing the elapsed time of each trial and finally the results.
TEST(ThreadStuff, NoPool) {
  FakeWork serial_work(NRUNS);

  for (int trial = 0; trial != 10; ++trial) {
    Timer stopwatch;
    int slot = 0;
    while (slot < NRUNS) {
      serial_work.do_work(slot);
      ++slot;
    }
    const double elapsed_us = stopwatch.getUs();
    printf("TOOK %.3f us\n", elapsed_us);
  }

  serial_work.confirm_correct();
}
// Pool timing: run the same work items through a ThreadPoolSpinLock with NRUNS threads,
// ten times, printing the elapsed time of each pool.work() call and finally the results.
TEST(ThreadStuff, Pool) {
  FakeWork shared_work(NRUNS);
  ThreadPoolSpinLock spin_pool(NRUNS, [&](int slot) { shared_work.do_work(slot); });

  spin_pool.start_spinning();
  // busy period before the first timed run
  work_microseconds(15000);

  for (int trial = 0; trial != 10; ++trial) {
    shared_work.reset();
    Timer stopwatch;
    spin_pool.work();
    const double elapsed_us = stopwatch.getUs();
    printf("TOOK %.3f us\n", elapsed_us);
  }

  shared_work.confirm_correct();
  spin_pool.stop_spinning();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment