Skip to content

Instantly share code, notes, and snippets.

@xinhuang
Created July 27, 2015 14:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xinhuang/16e03ec6d560df5ca03c to your computer and use it in GitHub Desktop.
Save xinhuang/16e03ec6d560df5ca03c to your computer and use it in GitHub Desktop.
#pragma once
#pragma warning(disable : 4503)
#include <cassert>
#include <functional>
#include <algorithm>
#include <tbb/tbb.h>
// RepeatTuple<N, T>::type names the TBB flow tuple that holds N copies of T.
// Only arities 1 through 10 are provided; for any other N the primary
// template is used, which intentionally has no `type` member.
template <size_t N, typename T> struct RepeatTuple {};
template <typename T> struct RepeatTuple<1, T> {
  using type = tbb::flow::tuple<T>;
};
template <typename T> struct RepeatTuple<2, T> {
  using type = tbb::flow::tuple<T, T>;
};
template <typename T> struct RepeatTuple<3, T> {
  using type = tbb::flow::tuple<T, T, T>;
};
template <typename T> struct RepeatTuple<4, T> {
  using type = tbb::flow::tuple<T, T, T, T>;
};
template <typename T> struct RepeatTuple<5, T> {
  using type = tbb::flow::tuple<T, T, T, T, T>;
};
template <typename T> struct RepeatTuple<6, T> {
  using type = tbb::flow::tuple<T, T, T, T, T, T>;
};
template <typename T> struct RepeatTuple<7, T> {
  using type = tbb::flow::tuple<T, T, T, T, T, T, T>;
};
template <typename T> struct RepeatTuple<8, T> {
  using type = tbb::flow::tuple<T, T, T, T, T, T, T, T>;
};
template <typename T> struct RepeatTuple<9, T> {
  using type = tbb::flow::tuple<T, T, T, T, T, T, T, T, T>;
};
template <typename T> struct RepeatTuple<10, T> {
  using type = tbb::flow::tuple<T, T, T, T, T, T, T, T, T, T>;
};
// Base case: the minimum of a single value is that value itself.
template <typename T> const T &unpack_min(const T &only) { return only; }

// Returns a reference to the smallest argument, compared with operator< via
// std::min. std::min keeps its first argument when neither is smaller, so on
// ties the leftmost argument is returned. All arguments must have type T.
template <typename T, typename... Ts>
const T &unpack_min(const T &head, const Ts &... tail) {
  // Reduce the tail first; the result refers to one of the caller's arguments.
  const auto &tail_min = unpack_min(tail...);
  return std::min(head, tail_min);
}
// Default reducer for merge_node: returns a reference to the smallest element
// of a homogeneous std::tuple holding 1 to 10 elements. The comparison is done
// by unpack_min (operator< through std::min), so on ties the leftmost element
// is returned.
struct tuple_min {
  template <typename T> const T &operator()(const std::tuple<T> &t) const {
    using std::get;
    return get<0>(t);
  }
  template <typename T> const T &operator()(const std::tuple<T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t),
                      get<5>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t),
                      get<5>(t), get<6>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t),
                      get<5>(t), get<6>(t), get<7>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t),
                      get<5>(t), get<6>(t), get<7>(t), get<8>(t));
  }
  template <typename T>
  const T &operator()(const std::tuple<T, T, T, T, T, T, T, T, T, T> &t) const {
    using std::get;
    return unpack_min(get<0>(t), get<1>(t), get<2>(t), get<3>(t), get<4>(t),
                      get<5>(t), get<6>(t), get<7>(t), get<8>(t), get<9>(t));
  }
};
template <typename T, typename M = tuple_min> class merge_node {
public:
typedef T input_type;
typedef T output_type;
typedef std::function<std::tuple<bool, T>(T)> func_t;
private:
size_t nary;
void *fnode = nullptr;
void *jnode = nullptr;
static void *create_fnode(tbb::flow::graph &g, size_t nary,
const func_t &func);
static void *create_jnode_with_edge(tbb::flow::graph &g, size_t nary,
void *pfnode);
public:
template <size_t N> struct input_ports_tuple {
typedef typename RepeatTuple<N, T>::type type;
};
template <typename Body>
merge_node(tbb::flow::graph &g, const Body &body, size_t nary)
: nary(nary) {
using namespace tbb::flow;
assert(nary > 0);
if (nary > __TBB_VARIADIC_MAX) {
throw std::runtime_error(
"Too many upstream nodes. Please check __TBB_VARIADIC_MAX");
}
if (nary > 1) {
fnode = create_fnode(g, nary, body);
jnode = create_jnode_with_edge(g, nary, fnode);
} else {
fnode = create_fnode(g, nary, body);
}
}
~merge_node() {
using namespace tbb::flow;
switch (nary) {
case 1:
break;
#define CASE_NARY(N) \
case N: \
delete reinterpret_cast< \
multifunction_node<input_ports_tuple<N>::type, tuple<T>> *>(fnode); \
delete reinterpret_cast<join_node<input_ports_tuple<N>::type> *>(jnode); \
break;
CASE_NARY(2);
CASE_NARY(3);
CASE_NARY(4);
CASE_NARY(5);
CASE_NARY(6);
CASE_NARY(7);
CASE_NARY(8);
CASE_NARY(9);
#undef CASE_NARY
default:
assert(false);
}
}
bool try_put(const T &r) {
assert(nary == 1);
return reinterpret_cast<
tbb::flow::multifunction_node<T, tbb::flow::tuple<T>> *>(fnode)
->try_put(r);
}
size_t get_nary() const { return nary; }
template <size_t K, size_t N, typename T,
typename U = typename T::input_ports_tuple<N>>
friend typename tbb::flow::tuple_element<
K,
typename tbb::flow::join_node<typename U::type>::input_ports_type>::type &
input_port(T &n) {
assert(N == n.get_nary());
assert(K < N);
typedef U::type tuple_type;
return tbb::flow::input_port<K>(
*reinterpret_cast<tbb::flow::join_node<tuple_type> *>(n.jnode));
}
static void make_edge(merge_node &src, merge_node &dest, size_t slot);
static void remove_edge(merge_node &src, merge_node &dest, size_t slot);
};
// Free-function spelling of Node::make_edge(src, dest, slot), so nodes that
// expose a slot-aware static make_edge can be wired with the same call shape
// as tbb::flow::make_edge plus a destination slot.
template <typename Node>
inline void make_edge(Node &src, Node &dest, size_t slot) {
  Node::make_edge(src, dest, slot);
}
// Free-function spelling of Node::remove_edge(src, dest, slot), the
// counterpart of the slot-aware make_edge wrapper.
template <typename Node>
inline void remove_edge(Node &src, Node &dest, size_t slot) {
  Node::remove_edge(src, dest, slot);
}
// A named namespace rather than an anonymous one: this is a header, and an
// anonymous namespace would give every translation unit its own copy of the
// helper while the merge_node member templates that call it are shared.
namespace merge_node_detail {
// Allocates a join_node<T> in `g` and connects it to the function node at
// `ptr`, which must point to a multifunction_node<T, tuple<R>>.
// Returns the new join node as void*; the caller takes ownership of it.
template <typename T, typename R>
void *create_jnode_with_edge(tbb::flow::graph &g, void *ptr) {
  using namespace tbb::flow;
  join_node<T> *jnode = new join_node<T>(g);
  multifunction_node<T, tuple<R>> *fnode =
      static_cast<multifunction_node<T, tuple<R>> *>(ptr);
  try {
    tbb::flow::make_edge(*jnode, *fnode);
  } catch (...) {
    // Not yet handed to the caller, so nobody else can free it.
    delete jnode;
    throw;
  }
  return jnode;
}
} // namespace merge_node_detail
// Keeps the existing ::create_jnode_with_edge<...> spelling resolvable.
using merge_node_detail::create_jnode_with_edge;
// Connects output port 0 of `src` to input slot `slot` of `dest`.
//
// src.fnode's concrete type depends on src's *own* arity, so its output port
// is recovered by switching on src.get_nary() and then kept only as a plain
// sender<T>. The receiving side is dest.fnode itself when dest is unary, or
// port `slot` of dest's join node otherwise.
//
// Throws std::runtime_error if either arity is unsupported in this build
// (arities 6..10 need __TBB_VARIADIC_MAX >= 10), or if `slot` is not a slot
// of `dest`.
template <typename T, typename M>
void merge_node<T, M>::make_edge(merge_node<T, M> &src, merge_node<T, M> &dest,
                                 size_t slot) {
  using namespace tbb::flow;
  const size_t dest_nary = dest.get_nary();
  assert(slot < dest_nary);
  assert(src.fnode);
  assert(dest.fnode);
  assert((dest_nary > 1 && dest.jnode) ||
         (dest_nary <= 1 && dest.jnode == nullptr));

  // Resolve src's output port 0 through src's arity, not dest's.
  sender<T> *out = nullptr;
  switch (src.get_nary()) {
  case 1:
    out = &output_port<0>(
        *static_cast<multifunction_node<T, tuple<T>> *>(src.fnode));
    break;
#define MERGE_NODE_SRC_PORT(N)                                                 \
  case N:                                                                      \
    out = &output_port<0>(*static_cast<multifunction_node<                     \
        typename input_ports_tuple<N>::type, tuple<T>> *>(src.fnode));         \
    break;
    MERGE_NODE_SRC_PORT(2)
    MERGE_NODE_SRC_PORT(3)
    MERGE_NODE_SRC_PORT(4)
    MERGE_NODE_SRC_PORT(5)
#if __TBB_VARIADIC_MAX >= 10
    MERGE_NODE_SRC_PORT(6)
    MERGE_NODE_SRC_PORT(7)
    MERGE_NODE_SRC_PORT(8)
    MERGE_NODE_SRC_PORT(9)
    MERGE_NODE_SRC_PORT(10)
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef MERGE_NODE_SRC_PORT

  // input_port<K, N> needs compile-time K and N, hence the nested dispatch.
#define MERGE_NODE_SLOT(K, N)                                                  \
  case K:                                                                      \
    tbb::flow::make_edge(*out, input_port<K, N>(dest));                        \
    return;
  switch (dest_nary) {
  case 1:
    tbb::flow::make_edge(
        *out, *static_cast<multifunction_node<T, tuple<T>> *>(dest.fnode));
    return;
  case 2:
    switch (slot) {
      MERGE_NODE_SLOT(0, 2)
      MERGE_NODE_SLOT(1, 2)
    }
    break;
  case 3:
    switch (slot) {
      MERGE_NODE_SLOT(0, 3)
      MERGE_NODE_SLOT(1, 3)
      MERGE_NODE_SLOT(2, 3)
    }
    break;
  case 4:
    switch (slot) {
      MERGE_NODE_SLOT(0, 4)
      MERGE_NODE_SLOT(1, 4)
      MERGE_NODE_SLOT(2, 4)
      MERGE_NODE_SLOT(3, 4)
    }
    break;
  case 5:
    switch (slot) {
      MERGE_NODE_SLOT(0, 5)
      MERGE_NODE_SLOT(1, 5)
      MERGE_NODE_SLOT(2, 5)
      MERGE_NODE_SLOT(3, 5)
      MERGE_NODE_SLOT(4, 5)
    }
    break;
#if __TBB_VARIADIC_MAX >= 10
  case 6:
    switch (slot) {
      MERGE_NODE_SLOT(0, 6)
      MERGE_NODE_SLOT(1, 6)
      MERGE_NODE_SLOT(2, 6)
      MERGE_NODE_SLOT(3, 6)
      MERGE_NODE_SLOT(4, 6)
      MERGE_NODE_SLOT(5, 6)
    }
    break;
  case 7:
    switch (slot) {
      MERGE_NODE_SLOT(0, 7)
      MERGE_NODE_SLOT(1, 7)
      MERGE_NODE_SLOT(2, 7)
      MERGE_NODE_SLOT(3, 7)
      MERGE_NODE_SLOT(4, 7)
      MERGE_NODE_SLOT(5, 7)
      MERGE_NODE_SLOT(6, 7)
    }
    break;
  case 8:
    switch (slot) {
      MERGE_NODE_SLOT(0, 8)
      MERGE_NODE_SLOT(1, 8)
      MERGE_NODE_SLOT(2, 8)
      MERGE_NODE_SLOT(3, 8)
      MERGE_NODE_SLOT(4, 8)
      MERGE_NODE_SLOT(5, 8)
      MERGE_NODE_SLOT(6, 8)
      MERGE_NODE_SLOT(7, 8)
    }
    break;
  case 9:
    switch (slot) {
      MERGE_NODE_SLOT(0, 9)
      MERGE_NODE_SLOT(1, 9)
      MERGE_NODE_SLOT(2, 9)
      MERGE_NODE_SLOT(3, 9)
      MERGE_NODE_SLOT(4, 9)
      MERGE_NODE_SLOT(5, 9)
      MERGE_NODE_SLOT(6, 9)
      MERGE_NODE_SLOT(7, 9)
      MERGE_NODE_SLOT(8, 9)
    }
    break;
  case 10:
    switch (slot) {
      MERGE_NODE_SLOT(0, 10)
      MERGE_NODE_SLOT(1, 10)
      MERGE_NODE_SLOT(2, 10)
      MERGE_NODE_SLOT(3, 10)
      MERGE_NODE_SLOT(4, 10)
      MERGE_NODE_SLOT(5, 10)
      MERGE_NODE_SLOT(6, 10)
      MERGE_NODE_SLOT(7, 10)
      MERGE_NODE_SLOT(8, 10)
      MERGE_NODE_SLOT(9, 10)
    }
    break;
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef MERGE_NODE_SLOT
  // Reached only when `slot` matched none of dest's slots.
  throw std::runtime_error("unsupported slot");
}
// Disconnects output port 0 of `src` from input slot `slot` of `dest`; the
// exact inverse of merge_node::make_edge.
//
// src.fnode's concrete type depends on src's *own* arity, so its output port
// is recovered by switching on src.get_nary() and then kept only as a plain
// sender<T>. The receiving side is dest.fnode itself when dest is unary, or
// port `slot` of dest's join node otherwise.
//
// Throws std::runtime_error if either arity is unsupported in this build
// (arities 6..10 need __TBB_VARIADIC_MAX >= 10), or if `slot` is not a slot
// of `dest`.
template <typename T, typename M>
void merge_node<T, M>::remove_edge(merge_node<T, M> &src,
                                   merge_node<T, M> &dest, size_t slot) {
  using namespace tbb::flow;
  const size_t dest_nary = dest.get_nary();
  assert(slot < dest_nary);
  assert(src.fnode);
  assert(dest.fnode);
  assert((dest_nary > 1 && dest.jnode) ||
         (dest_nary <= 1 && dest.jnode == nullptr));

  // Resolve src's output port 0 through src's arity, not dest's.
  sender<T> *out = nullptr;
  switch (src.get_nary()) {
  case 1:
    out = &output_port<0>(
        *static_cast<multifunction_node<T, tuple<T>> *>(src.fnode));
    break;
#define MERGE_NODE_SRC_PORT(N)                                                 \
  case N:                                                                      \
    out = &output_port<0>(*static_cast<multifunction_node<                     \
        typename input_ports_tuple<N>::type, tuple<T>> *>(src.fnode));         \
    break;
    MERGE_NODE_SRC_PORT(2)
    MERGE_NODE_SRC_PORT(3)
    MERGE_NODE_SRC_PORT(4)
    MERGE_NODE_SRC_PORT(5)
#if __TBB_VARIADIC_MAX >= 10
    MERGE_NODE_SRC_PORT(6)
    MERGE_NODE_SRC_PORT(7)
    MERGE_NODE_SRC_PORT(8)
    MERGE_NODE_SRC_PORT(9)
    MERGE_NODE_SRC_PORT(10)
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef MERGE_NODE_SRC_PORT

  // input_port<K, N> needs compile-time K and N, hence the nested dispatch.
#define MERGE_NODE_SLOT(K, N)                                                  \
  case K:                                                                      \
    tbb::flow::remove_edge(*out, input_port<K, N>(dest));                      \
    return;
  switch (dest_nary) {
  case 1:
    tbb::flow::remove_edge(
        *out, *static_cast<multifunction_node<T, tuple<T>> *>(dest.fnode));
    return;
  case 2:
    switch (slot) {
      MERGE_NODE_SLOT(0, 2)
      MERGE_NODE_SLOT(1, 2)
    }
    break;
  case 3:
    switch (slot) {
      MERGE_NODE_SLOT(0, 3)
      MERGE_NODE_SLOT(1, 3)
      MERGE_NODE_SLOT(2, 3)
    }
    break;
  case 4:
    switch (slot) {
      MERGE_NODE_SLOT(0, 4)
      MERGE_NODE_SLOT(1, 4)
      MERGE_NODE_SLOT(2, 4)
      MERGE_NODE_SLOT(3, 4)
    }
    break;
  case 5:
    switch (slot) {
      MERGE_NODE_SLOT(0, 5)
      MERGE_NODE_SLOT(1, 5)
      MERGE_NODE_SLOT(2, 5)
      MERGE_NODE_SLOT(3, 5)
      MERGE_NODE_SLOT(4, 5)
    }
    break;
#if __TBB_VARIADIC_MAX >= 10
  case 6:
    switch (slot) {
      MERGE_NODE_SLOT(0, 6)
      MERGE_NODE_SLOT(1, 6)
      MERGE_NODE_SLOT(2, 6)
      MERGE_NODE_SLOT(3, 6)
      MERGE_NODE_SLOT(4, 6)
      MERGE_NODE_SLOT(5, 6)
    }
    break;
  case 7:
    switch (slot) {
      MERGE_NODE_SLOT(0, 7)
      MERGE_NODE_SLOT(1, 7)
      MERGE_NODE_SLOT(2, 7)
      MERGE_NODE_SLOT(3, 7)
      MERGE_NODE_SLOT(4, 7)
      MERGE_NODE_SLOT(5, 7)
      MERGE_NODE_SLOT(6, 7)
    }
    break;
  case 8:
    switch (slot) {
      MERGE_NODE_SLOT(0, 8)
      MERGE_NODE_SLOT(1, 8)
      MERGE_NODE_SLOT(2, 8)
      MERGE_NODE_SLOT(3, 8)
      MERGE_NODE_SLOT(4, 8)
      MERGE_NODE_SLOT(5, 8)
      MERGE_NODE_SLOT(6, 8)
      MERGE_NODE_SLOT(7, 8)
    }
    break;
  case 9:
    switch (slot) {
      MERGE_NODE_SLOT(0, 9)
      MERGE_NODE_SLOT(1, 9)
      MERGE_NODE_SLOT(2, 9)
      MERGE_NODE_SLOT(3, 9)
      MERGE_NODE_SLOT(4, 9)
      MERGE_NODE_SLOT(5, 9)
      MERGE_NODE_SLOT(6, 9)
      MERGE_NODE_SLOT(7, 9)
      MERGE_NODE_SLOT(8, 9)
    }
    break;
  case 10:
    switch (slot) {
      MERGE_NODE_SLOT(0, 10)
      MERGE_NODE_SLOT(1, 10)
      MERGE_NODE_SLOT(2, 10)
      MERGE_NODE_SLOT(3, 10)
      MERGE_NODE_SLOT(4, 10)
      MERGE_NODE_SLOT(5, 10)
      MERGE_NODE_SLOT(6, 10)
      MERGE_NODE_SLOT(7, 10)
      MERGE_NODE_SLOT(8, 10)
      MERGE_NODE_SLOT(9, 10)
    }
    break;
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef MERGE_NODE_SLOT
  // Reached only when `slot` matched none of dest's slots.
  throw std::runtime_error("unsupported slot");
}
// Allocates the serial (concurrency 1) multifunction_node that applies M and
// the user body, emitting the body's value on output port 0 only when the
// body's bool flag is true.
//
// For nary == 1 the node's input is a plain T, which is the type try_put,
// make_edge and remove_edge address it as. The value is wrapped in a
// one-element tuple so M is still applied exactly as before. For nary > 1
// the input is the nary-element tuple produced by the join node.
//
// Returns the new node as void*; the caller takes ownership of it.
// Throws std::runtime_error for an arity this build does not support
// (arities 6..10 need __TBB_VARIADIC_MAX >= 10).
template <typename T, typename M>
void *merge_node<T, M>::create_fnode(tbb::flow::graph &g, size_t nary,
                                     const func_t &func) {
  using namespace tbb::flow;
  assert(nary > 0);
  switch (nary) {
  case 1: {
    typedef typename input_ports_tuple<1>::type wrapped_t;
    typedef multifunction_node<T, tuple<T>> node_t;
    return new node_t(
        g, 1, [=](const T &v, typename node_t::output_ports_type &op) {
          M m;
          const wrapped_t wrapped(v);
          const auto &r = func(m(wrapped));
          if (std::get<0>(r))
            std::get<0>(op).try_put(std::get<1>(r));
        });
  }
#define CASE_NARY(N)                                                           \
  case N: {                                                                    \
    typedef typename input_ports_tuple<N>::type input_t;                       \
    typedef multifunction_node<input_t, tuple<T>> node_t;                      \
    return new node_t(                                                         \
        g, 1, [=](const input_t &t, typename node_t::output_ports_type &op) { \
          M m;                                                                 \
          const auto &r = func(m(t));                                          \
          if (std::get<0>(r))                                                  \
            std::get<0>(op).try_put(std::get<1>(r));                           \
        });                                                                    \
  }
    CASE_NARY(2)
    CASE_NARY(3)
    CASE_NARY(4)
    CASE_NARY(5)
#if __TBB_VARIADIC_MAX >= 10
    CASE_NARY(6)
    CASE_NARY(7)
    CASE_NARY(8)
    CASE_NARY(9)
    CASE_NARY(10)
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef CASE_NARY
}
// Allocates the join node for an `nary`-input merge node (nary >= 2) and
// connects it to the function node at `ptr`, whose input type must be the
// nary-element tuple of T.
// Returns the new join node as void*; the caller takes ownership of it.
// Throws std::runtime_error for an arity this build does not support
// (arities 6..10 need __TBB_VARIADIC_MAX >= 10).
template <typename T, typename M>
void *merge_node<T, M>::create_jnode_with_edge(tbb::flow::graph &g,
                                                size_t nary, void *ptr) {
  assert(nary > 1);
  // `typename` is required: input_ports_tuple<N>::type is a dependent name.
#define CASE_NARY(N)                                                           \
  case N:                                                                      \
    return ::create_jnode_with_edge<typename input_ports_tuple<N>::type, T>(  \
        g, ptr);
  switch (nary) {
    CASE_NARY(2)
    CASE_NARY(3)
    CASE_NARY(4)
    CASE_NARY(5)
#if __TBB_VARIADIC_MAX >= 10
    CASE_NARY(6)
    CASE_NARY(7)
    CASE_NARY(8)
    CASE_NARY(9)
    CASE_NARY(10)
#endif // __TBB_VARIADIC_MAX
  default:
    throw std::runtime_error("unsupported nary");
  }
#undef CASE_NARY
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment