Last active
April 22, 2020 14:01
-
-
Save lochbrunner/8d88bc5d6b03f0ada0180d0a34183238 to your computer and use it in GitHub Desktop.
Example of using a generic thread pool for CUDA-like kernel calls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <cstdint> | |
/// Busy-work "kernel": iterates a Fibonacci-style recurrence modulo 1000
/// and writes the final value through `result`. Stands in for a CUDA kernel
/// in the thread-pool demo.
///
/// @param begin      seed value for the recurrence (prev[0] starts here)
/// @param result     out-parameter receiving prev[0] after the last step
/// @param iterations number of recurrence steps; defaults to the original
///                   hard-coded 1'000'000'000, so existing callers are unchanged
// `inline` because this definition lives in a header: without it, two TUs
// including foo.hpp would violate the ODR.
inline void foo(const uint64_t begin, uint64_t *result,
                const uint64_t iterations = 1000000000ULL)
{
    uint64_t prev[] = {begin, 0};
    for (uint64_t i = 0; i < iterations; ++i)
    {
        const auto tmp = (prev[0] + prev[1]) % 1000;
        prev[1] = prev[0];
        prev[0] = tmp;
    }
    *result = prev[0];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <ostream>
#include <streambuf>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "foo.hpp"
namespace {
/// Stream buffer that discards everything written to it (a /dev/null sink).
class NullBuffer : public std::streambuf {
   public:
    // Report success for every character. not_eof() also handles the
    // c == EOF flush case: returning `c` there would signal failure and
    // put the stream into a fail state.
    int overflow(int c) override { return traits_type::not_eof(c); }
};

NullBuffer null_buffer;
// Sink behind the LOG macro when VERBOSE is disabled.
std::ostream null_stream(&null_buffer);
}  // namespace
// #define VERBOSE | |
#ifdef VERBOSE | |
#define LOG std::cerr | |
#else | |
#define LOG null_stream | |
#endif | |
namespace detail {
// Base case: every tuple element has been unpacked into `u...` — invoke f.
// std::forward preserves the value category of each unpacked argument
// (the original passed plain lvalues, which the author flagged as wrong).
template <size_t argIndex, size_t argSize, typename... Args, typename... Unpacked, typename F>
inline typename std::enable_if<(argIndex == argSize), void>::type apply_args_impl(const std::tuple<Args...>&, F f,
                                                                                  Unpacked&&... u) {
    f(std::forward<Unpacked>(u)...);
}

// Recursive case: append tuple element `argIndex` to the unpacked pack
// and recurse towards the base case above.
template <size_t argIndex, size_t argSize, typename... Args, typename... Unpacked, typename F>
inline typename std::enable_if<(argIndex < argSize), void>::type apply_args_impl(const std::tuple<Args...>& t, F f,
                                                                                 Unpacked&&... u) {
    apply_args_impl<argIndex + 1, argSize>(t, f, std::forward<Unpacked>(u)..., std::get<argIndex>(t));
}
}  // namespace detail
template <typename CallArg, typename... Args, typename F> | |
inline void apply_args(const CallArg& arg, const std::tuple<Args...>& t, F f) { | |
detail::apply_args_impl<0, sizeof...(Args)>(t, f, arg); | |
} | |
// Identifies one worker inside the pool — the CPU analogue of a CUDA thread
// index passed to every kernel invocation.
struct ThreadIndex {
    int index;   // this worker's slot, in [0, dim)
    size_t dim;  // intended as the pool size (NOTE(review): the ThreadPool
                 // member it is filled from is never initialized — verify)
};
template <typename... Args> | |
class ThreadPool { | |
public: | |
using Function = typename std::function<void(const ThreadIndex, Args...)>; | |
private: | |
size_t _size; | |
std::vector<std::thread> _workers; | |
std::condition_variable _finished_cv; | |
std::condition_variable _next_cycle; | |
volatile bool _terminate; | |
volatile int _cycle_id; | |
std::tuple<Args...> _buffers; | |
std::vector<bool> _finish_flags; | |
std::mutex _cv_m; | |
std::mutex _cv_finish; | |
Function _func; | |
void worker_loop(int worker_id) { | |
{ | |
std::unique_lock<std::mutex> lk(_cv_m); | |
_finish_flags[worker_id] = true; | |
_finished_cv.notify_one(); | |
LOG << "#" << worker_id << " ready" << std::endl; | |
} | |
while (true) { | |
std::unique_lock<std::mutex> lk(_cv_m); | |
_next_cycle.wait(lk); | |
LOG << "#" << worker_id << (_terminate ? " terminates" : " works") << std::endl; | |
if (_terminate) break; | |
lk.unlock(); | |
apply_args(ThreadIndex{worker_id, _size}, _buffers, _func); | |
_finish_flags[worker_id] = true; | |
LOG << "#" << worker_id << " finish" << std::endl; | |
_finished_cv.notify_one(); | |
} | |
}; | |
inline void wait_for_workers() { | |
std::unique_lock<std::mutex> lk(_cv_finish); | |
_finished_cv.wait(lk, [this] { | |
return std::all_of(_finish_flags.begin(), _finish_flags.end(), [](bool finished) { return finished; }); | |
}); | |
} | |
public: | |
explicit ThreadPool(size_t size) : _terminate(false) { | |
for (int i = 0; i < size; ++i) { | |
_workers.push_back(std::thread(&ThreadPool::worker_loop, this, i)); | |
_finish_flags.push_back(false); | |
} | |
wait_for_workers(); | |
} | |
explicit ThreadPool() : _terminate(false) { | |
const auto size = std::thread::hardware_concurrency(); | |
LOG << "Detected " << size << " cores" << std::endl; | |
for (int i = 0; i < size; ++i) { | |
_workers.push_back(std::thread(&ThreadPool::worker_loop, this, i)); | |
_finish_flags.push_back(false); | |
} | |
wait_for_workers(); | |
} | |
void terminate() { | |
_terminate = true; | |
{ | |
std::unique_lock<std::mutex> lk(_cv_m); | |
LOG << "terminate" << std::endl; | |
} | |
_next_cycle.notify_all(); | |
for (auto& worker : _workers) { | |
worker.join(); | |
} | |
} | |
void process(Function func, Args... args) { | |
_func = func; | |
++_cycle_id; | |
_buffers = std::tie<Args...>(args...); | |
for (int i = 0; i < _workers.size(); ++i) { | |
_finish_flags[i] = false; | |
} | |
LOG << "trigger..." << std::endl; | |
_next_cycle.notify_all(); | |
wait_for_workers(); | |
LOG << "...process finished" << std::endl; | |
} | |
}; | |
// Kernel entry point for the pool: only the first five workers (the ones with
// data in the 5-element batch) do real work; any extra workers just idle
// briefly so the demo exercises uneven finish times.
void foo_dispatcher(const ThreadIndex index, const uint64_t* input, uint64_t* output) {
    const auto slot = index.index;
    if (slot <= 4) {
        foo(input[slot], &output[slot]);
    } else {
        std::this_thread::sleep_for(std::chrono::microseconds(10));
    }
}
// Run one pool cycle over a fixed 5-element batch and print the results to stderr.
void batch(ThreadPool<uint64_t*, uint64_t*>& pool) {
    std::vector<uint64_t> input{1, 2, 3, 4, 6};
    std::vector<uint64_t> output(input.size());
    pool.process(foo_dispatcher, input.data(), output.data());
    std::cerr << "Processed: ";
    for (const auto& value : output) {
        std::cerr << " " << value;
    }
    std::cerr << std::endl;
}
in main(int, char**) { | |
ThreadPool<uint64_t*, uint64_t*> pool; | |
// Wait for startup | |
std::this_thread::sleep_for(std::chrono::microseconds(1)); | |
batch(pool); | |
batch(pool); | |
batch(pool); | |
std::this_thread::sleep_for(std::chrono::microseconds(1)); | |
pool.terminate(); | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment