Skip to content

Instantly share code, notes, and snippets.

@lochbrunner
Last active April 22, 2020 14:01
Show Gist options
  • Save lochbrunner/8d88bc5d6b03f0ada0180d0a34183238 to your computer and use it in GitHub Desktop.
Example of using a generic Thread Pool for CUDA like kernel calls
#pragma once
#include <cstdint>
// Busy-work "kernel": iterates a Fibonacci-like recurrence modulo 1000 and
// writes the final value through `result`. Exists to keep a worker thread
// busy for a measurable amount of time.
//
// @param begin       seed of the recurrence
// @param result      out-parameter receiving the final value (must be non-null)
// @param iterations  number of recurrence steps; the default preserves the
//                    original heavy workload, smaller values make the
//                    function cheap to call in tests
void foo(const uint64_t begin, uint64_t *result, const uint64_t iterations = 1000000000)
{
    uint64_t prev[] = {begin, 0};
    for (uint64_t i = 0; i < iterations; ++i)
    {
        const auto tmp = (prev[0] + prev[1]) % 1000;
        prev[1] = prev[0];
        prev[0] = tmp;
    }
    *result = prev[0];
}
#include "foo.hpp"

#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <ostream>
#include <streambuf>
#include <thread>
#include <tuple>
#include <vector>
namespace {
// Stream buffer that silently discards every byte written to it.
class NullBuffer : public std::streambuf {
   public:
    // Report success (return the character, not EOF) while dropping the data,
    // so the owning stream never enters a failed state.
    int overflow(int c) override { return c; }
};
NullBuffer null_buffer;
// Sink stream used when logging is compiled out.
std::ostream null_stream(&null_buffer);
}  // namespace
// Uncomment to route LOG output to stderr; by default all LOG writes are
// swallowed by the file-local null stream.
// #define VERBOSE
#ifdef VERBOSE
#define LOG std::cerr
#else
#define LOG null_stream
#endif
namespace detail {
// Recursion terminator: every tuple element has been unpacked into u...,
// so invoke f with the accumulated arguments. std::forward preserves the
// value categories of the unpacked arguments (resolves the old inline TODO).
template <size_t argIndex, size_t argSize, typename... Args, typename... Unpacked, typename F>
inline typename std::enable_if<(argIndex == argSize), void>::type apply_args_impl(const std::tuple<Args...>& t, F f,
                                                                                  Unpacked&&... u) {
    f(std::forward<Unpacked>(u)...);
}
// Recursive case: append tuple element `argIndex` to the unpacked pack and
// recurse on the next index.
template <size_t argIndex, size_t argSize, typename... Args, typename... Unpacked, typename F>
inline typename std::enable_if<(argIndex < argSize), void>::type apply_args_impl(const std::tuple<Args...>& t, F f,
                                                                                 Unpacked&&... u) {
    apply_args_impl<argIndex + 1, argSize>(t, f, std::forward<Unpacked>(u)..., std::get<argIndex>(t));
}
}  // namespace detail

// Calls f(arg, tuple elements...): essentially std::apply with one extra
// leading argument prepended to the tuple's contents.
template <typename CallArg, typename... Args, typename F>
inline void apply_args(const CallArg& arg, const std::tuple<Args...>& t, F f) {
    detail::apply_args_impl<0, sizeof...(Args)>(t, f, arg);
}
// Identifies one worker within a pool dispatch cycle (CUDA-like index).
struct ThreadIndex {
    // This worker's id; the pool assigns ids 0..dim-1.
    int index;
    // Total number of workers in the pool.
    size_t dim;
};
// Fixed-size pool of worker threads that repeatedly run one kernel-style
// function on every worker per process() call ("CUDA-like" dispatch).
//
// Fixes over the original version:
//  - _size and _cycle_id are now initialized (the original read an
//    uninitialized _size when building ThreadIndex);
//  - _finish_flags is fully sized BEFORE any worker thread starts (the
//    original pushed the flag after spawning the thread that indexes it);
//  - all shared state is guarded by the single mutex _cv_m, and both
//    condition variables wait on predicates, so a notify that fires before a
//    thread starts waiting can no longer be lost (the original could deadlock
//    on a missed _next_cycle notification and waited on _finished_cv with a
//    mutex unrelated to the flags it read);
//  - a destructor joins the workers even if terminate() was never called.
template <typename... Args>
class ThreadPool {
   public:
    // Kernel signature: receives this worker's index plus the batch arguments.
    using Function = typename std::function<void(const ThreadIndex, Args...)>;

   private:
    size_t _size;                          // number of workers; set once in the ctor
    std::vector<std::thread> _workers;
    std::condition_variable _finished_cv;  // workers -> process(): "I am done"
    std::condition_variable _next_cycle;   // process() -> workers: "new work"
    bool _terminate;                       // guarded by _cv_m
    int _cycle_id;                         // guarded by _cv_m; bumped once per process()
    std::tuple<Args...> _buffers;          // arguments of the current cycle
    std::vector<bool> _finish_flags;       // guarded by _cv_m; one entry per worker
    std::mutex _cv_m;                      // single mutex guarding all shared state
    Function _func;                        // kernel of the current cycle

    // Body executed by each worker thread.
    void worker_loop(int worker_id) {
        int seen_cycle;
        {
            std::unique_lock<std::mutex> lk(_cv_m);
            seen_cycle = _cycle_id;
            _finish_flags[worker_id] = true;  // signal "ready" to the constructor
            LOG << "#" << worker_id << " ready" << std::endl;
        }
        _finished_cv.notify_one();
        while (true) {
            {
                std::unique_lock<std::mutex> lk(_cv_m);
                // Predicate guards against lost notifications: if process()
                // already bumped _cycle_id we proceed immediately.
                _next_cycle.wait(lk, [this, seen_cycle] { return _terminate || _cycle_id != seen_cycle; });
                LOG << "#" << worker_id << (_terminate ? " terminates" : " works") << std::endl;
                if (_terminate) break;
                seen_cycle = _cycle_id;
            }
            // Run the kernel outside the lock; _func/_buffers are stable until
            // every worker reports back.
            apply_args(ThreadIndex{worker_id, _size}, _buffers, _func);
            {
                std::unique_lock<std::mutex> lk(_cv_m);
                _finish_flags[worker_id] = true;
                LOG << "#" << worker_id << " finish" << std::endl;
            }
            _finished_cv.notify_one();
        }
    }

    // Blocks until every worker has raised its finish flag.
    inline void wait_for_workers() {
        std::unique_lock<std::mutex> lk(_cv_m);
        _finished_cv.wait(lk, [this] {
            return std::all_of(_finish_flags.begin(), _finish_flags.end(), [](bool finished) { return finished; });
        });
    }

    // Shared constructor logic: spawn `size` workers and wait for readiness.
    void start(size_t size) {
        _finish_flags.assign(size, false);  // fully sized before any thread runs
        for (size_t i = 0; i < size; ++i) {
            _workers.emplace_back(&ThreadPool::worker_loop, this, static_cast<int>(i));
        }
        wait_for_workers();
    }

   public:
    explicit ThreadPool(size_t size) : _size(size), _terminate(false), _cycle_id(0) { start(size); }

    explicit ThreadPool() : _size(std::thread::hardware_concurrency()), _terminate(false), _cycle_id(0) {
        LOG << "Detected " << _size << " cores" << std::endl;
        start(_size);
    }

    // Joins the workers even if the user forgot to call terminate().
    ~ThreadPool() {
        if (!_workers.empty() && _workers.front().joinable()) terminate();
    }

    // Stops all workers and joins them; safe to call more than once.
    void terminate() {
        {
            std::unique_lock<std::mutex> lk(_cv_m);
            _terminate = true;
            LOG << "terminate" << std::endl;
        }
        _next_cycle.notify_all();
        for (auto& worker : _workers) {
            if (worker.joinable()) worker.join();
        }
    }

    // Runs func(ThreadIndex{i, size}, args...) once on every worker and
    // blocks until all of them have finished.
    void process(Function func, Args... args) {
        {
            std::unique_lock<std::mutex> lk(_cv_m);
            _func = func;
            _buffers = std::make_tuple(args...);
            ++_cycle_id;  // workers wait for this to change
            std::fill(_finish_flags.begin(), _finish_flags.end(), false);
            LOG << "trigger..." << std::endl;
        }
        _next_cycle.notify_all();
        wait_for_workers();
        LOG << "...process finished" << std::endl;
    }
};
// Kernel wrapper: workers 0..4 run foo on their slot of the batch; any
// surplus worker just idles briefly so it still completes the cycle.
void foo_dispatcher(const ThreadIndex index, const uint64_t* input, uint64_t* output) {
    const auto slot = index.index;
    if (slot <= 4) {
        foo(input[slot], &output[slot]);
    } else {
        std::this_thread::sleep_for(std::chrono::microseconds(10));
    }
}
// Runs one dispatch cycle over a fixed five-element batch and prints the
// results to stderr.
void batch(ThreadPool<uint64_t*, uint64_t*>& pool) {
    std::vector<uint64_t> input{1, 2, 3, 4, 6};
    std::vector<uint64_t> output(5);
    pool.process(foo_dispatcher, &input[0], &output[0]);
    std::cerr << "Processed: ";
    for (const auto& value : output) {
        std::cerr << " " << value;
    }
    std::cerr << std::endl;
}
in main(int, char**) {
ThreadPool<uint64_t*, uint64_t*> pool;
// Wait for startup
std::this_thread::sleep_for(std::chrono::microseconds(1));
batch(pool);
batch(pool);
batch(pool);
std::this_thread::sleep_for(std::chrono::microseconds(1));
pool.terminate();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment