Created
July 27, 2015 14:57
-
-
Save xinhuang/16e03ec6d560df5ca03c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#pragma warning(disable : 4503)
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <tuple>

#include <tbb/tbb.h>
template <size_t N, typename T> struct RepeatTuple {}; | |
template <typename T> struct RepeatTuple<1, T> { | |
typedef tbb::flow::tuple<T> type; | |
}; | |
template <typename T> struct RepeatTuple<2, T> { | |
typedef tbb::flow::tuple<T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<3, T> { | |
typedef tbb::flow::tuple<T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<4, T> { | |
typedef tbb::flow::tuple<T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<5, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<6, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<7, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<8, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<9, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T, T, T, T, T> type; | |
}; | |
template <typename T> struct RepeatTuple<10, T> { | |
typedef tbb::flow::tuple<T, T, T, T, T, T, T, T, T, T> type; | |
}; | |
// Base case: the minimum of a single value is that value itself.
template <typename T> const T &unpack_min(const T &v) { return v; }

// Recursive case: returns a reference to the smallest of the arguments.
// The result is folded from the right, i.e. min(first, min(rest...)), so on
// ties std::min keeps the earlier argument. All arguments must have the same
// type for std::min to deduce.
template <typename T, typename... Ts>
const T &unpack_min(const T &first, const Ts &... rest) {
  const auto &tail_min = unpack_min(rest...);
  return std::min(first, tail_min);
}
// Function object that returns a reference to the smallest element of a
// std::tuple whose elements all share one type.
//
// A single variadic overload replaces the former hand-written overloads for
// arities 1..10, so any arity >= 1 is accepted. Elements are compared with
// std::min folded from the right -- min(e0, min(e1, ... eN)) -- which is the
// same association the per-arity overloads used; on ties the earliest element
// is returned.
struct tuple_min {
private:
  // Right fold of std::min over elements I..Last of Tuple.
  template <typename Tuple, std::size_t I, std::size_t Last> struct fold_min {
    static const typename std::tuple_element<I, Tuple>::type &
    apply(const Tuple &t) {
      return std::min(std::get<I>(t), fold_min<Tuple, I + 1, Last>::apply(t));
    }
  };

  // End of the fold: the last element is the minimum of itself.
  template <typename Tuple, std::size_t Last>
  struct fold_min<Tuple, Last, Last> {
    static const typename std::tuple_element<Last, Tuple>::type &
    apply(const Tuple &t) {
      return std::get<Last>(t);
    }
  };

public:
  // Returns a reference into `t`; it is valid only as long as `t` is.
  // Tuples mixing element types are rejected at compile time by std::min.
  template <typename T, typename... Rest>
  const T &operator()(const std::tuple<T, Rest...> &t) const {
    return fold_min<std::tuple<T, Rest...>, 0, sizeof...(Rest)>::apply(t);
  }
};
template <typename T, typename M = tuple_min> class merge_node { | |
public: | |
typedef T input_type; | |
typedef T output_type; | |
typedef std::function<std::tuple<bool, T>(T)> func_t; | |
private: | |
size_t nary; | |
void *fnode = nullptr; | |
void *jnode = nullptr; | |
static void *create_fnode(tbb::flow::graph &g, size_t nary, | |
const func_t &func); | |
static void *create_jnode_with_edge(tbb::flow::graph &g, size_t nary, | |
void *pfnode); | |
public: | |
template <size_t N> struct input_ports_tuple { | |
typedef typename RepeatTuple<N, T>::type type; | |
}; | |
template <typename Body> | |
merge_node(tbb::flow::graph &g, const Body &body, size_t nary) | |
: nary(nary) { | |
using namespace tbb::flow; | |
assert(nary > 0); | |
if (nary > __TBB_VARIADIC_MAX) { | |
throw std::runtime_error( | |
"Too many upstream nodes. Please check __TBB_VARIADIC_MAX"); | |
} | |
if (nary > 1) { | |
fnode = create_fnode(g, nary, body); | |
jnode = create_jnode_with_edge(g, nary, fnode); | |
} else { | |
fnode = create_fnode(g, nary, body); | |
} | |
} | |
~merge_node() { | |
using namespace tbb::flow; | |
switch (nary) { | |
case 1: | |
break; | |
#define CASE_NARY(N) \ | |
case N: \ | |
delete reinterpret_cast< \ | |
multifunction_node<input_ports_tuple<N>::type, tuple<T>> *>(fnode); \ | |
delete reinterpret_cast<join_node<input_ports_tuple<N>::type> *>(jnode); \ | |
break; | |
CASE_NARY(2); | |
CASE_NARY(3); | |
CASE_NARY(4); | |
CASE_NARY(5); | |
CASE_NARY(6); | |
CASE_NARY(7); | |
CASE_NARY(8); | |
CASE_NARY(9); | |
#undef CASE_NARY | |
default: | |
assert(false); | |
} | |
} | |
bool try_put(const T &r) { | |
assert(nary == 1); | |
return reinterpret_cast< | |
tbb::flow::multifunction_node<T, tbb::flow::tuple<T>> *>(fnode) | |
->try_put(r); | |
} | |
size_t get_nary() const { return nary; } | |
template <size_t K, size_t N, typename T, | |
typename U = typename T::input_ports_tuple<N>> | |
friend typename tbb::flow::tuple_element< | |
K, | |
typename tbb::flow::join_node<typename U::type>::input_ports_type>::type & | |
input_port(T &n) { | |
assert(N == n.get_nary()); | |
assert(K < N); | |
typedef U::type tuple_type; | |
return tbb::flow::input_port<K>( | |
*reinterpret_cast<tbb::flow::join_node<tuple_type> *>(n.jnode)); | |
} | |
static void make_edge(merge_node &src, merge_node &dest, size_t slot); | |
static void remove_edge(merge_node &src, merge_node &dest, size_t slot); | |
}; | |
// Free-function spelling of T::make_edge: connects the output of `src` to
// input `slot` of `dest` by forwarding to the node type's static member.
template <typename T>
inline void make_edge(T &src, T &dest, size_t slot) {
  typedef T node_type;
  node_type::make_edge(src, dest, slot);
}
// Free-function spelling of T::remove_edge: disconnects the output of `src`
// from input `slot` of `dest` by forwarding to the node type's static member.
template <typename T>
inline void remove_edge(T &src, T &dest, size_t slot) {
  typedef T node_type;
  node_type::remove_edge(src, dest, slot);
}
namespace { | |
template <typename T, typename R> | |
void *create_jnode_with_edge(tbb::flow::graph &g, void *ptr) { | |
using namespace tbb::flow; | |
join_node<T> *jnode = new join_node<T>(g); | |
multifunction_node<T, tuple<R>> *fnode = | |
reinterpret_cast<multifunction_node<T, tuple<R>> *>(ptr); | |
make_edge(*jnode, *fnode); | |
return jnode; | |
} | |
} | |
template <typename T, typename M> | |
void merge_node<T, M>::make_edge(merge_node<T, M> &src, merge_node<T, M> &dest, | |
size_t slot) { | |
using namespace tbb::flow; | |
size_t dest_nary = dest.get_nary(); | |
assert(slot < dest_nary); | |
assert(src.fnode); | |
assert(dest.fnode); | |
assert(dest_nary > 1 && dest.jnode || | |
dest_nary <= 1 && dest.jnode == nullptr); | |
switch (dest_nary) { | |
case 1: | |
tbb::flow::make_edge( | |
output_port<0>( | |
*reinterpret_cast<multifunction_node<T, tuple<T>> *>(src.fnode)), | |
*reinterpret_cast<multifunction_node<T, tuple<T>> *>(dest.fnode)); | |
return; | |
#define CONNECT_FNODE_TO_PORT(K, N) \ | |
tbb::flow::make_edge( \ | |
output_port<0>(*reinterpret_cast<multifunction_node< \ | |
typename RepeatTuple<N, T>::type, tuple<T>> *>(src.fnode)), \ | |
input_port<K, N>(dest)); | |
#define CASE_SLOT(K, N) \ | |
case K: \ | |
CONNECT_FNODE_TO_PORT(K, N); \ | |
return; | |
#define CASE_SLOT_DEFAULT() \ | |
default: \ | |
throw std::runtime_error("unsupported nary"); | |
case 2: | |
switch (slot) { | |
CASE_SLOT(0, 2); | |
CASE_SLOT(1, 2); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 3: | |
switch (slot) { | |
CASE_SLOT(0, 3); | |
CASE_SLOT(1, 3); | |
CASE_SLOT(2, 3); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 4: | |
switch (slot) { | |
CASE_SLOT(0, 4); | |
CASE_SLOT(1, 4); | |
CASE_SLOT(2, 4); | |
CASE_SLOT(3, 4); | |
CASE_SLOT_DEFAULT(); | |
} | |
#if __TBB_VARIADIC_MAX >= 10 | |
case 5: | |
switch (slot) { | |
CASE_SLOT(0, 5); | |
CASE_SLOT(1, 5); | |
CASE_SLOT(2, 5); | |
CASE_SLOT(3, 5); | |
CASE_SLOT(4, 5); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 6: | |
switch (slot) { | |
CASE_SLOT(0, 6); | |
CASE_SLOT(1, 6); | |
CASE_SLOT(2, 6); | |
CASE_SLOT(3, 6); | |
CASE_SLOT(4, 6); | |
CASE_SLOT(5, 6); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 7: | |
switch (slot) { | |
CASE_SLOT(0, 7); | |
CASE_SLOT(1, 7); | |
CASE_SLOT(2, 7); | |
CASE_SLOT(3, 7); | |
CASE_SLOT(4, 7); | |
CASE_SLOT(5, 7); | |
CASE_SLOT(6, 7); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 8: | |
switch (slot) { | |
CASE_SLOT(0, 8); | |
CASE_SLOT(1, 8); | |
CASE_SLOT(2, 8); | |
CASE_SLOT(3, 8); | |
CASE_SLOT(4, 8); | |
CASE_SLOT(5, 8); | |
CASE_SLOT(6, 8); | |
CASE_SLOT(7, 8); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 9: | |
switch (slot) { | |
CASE_SLOT(0, 9); | |
CASE_SLOT(1, 9); | |
CASE_SLOT(2, 9); | |
CASE_SLOT(3, 9); | |
CASE_SLOT(4, 9); | |
CASE_SLOT(5, 9); | |
CASE_SLOT(6, 9); | |
CASE_SLOT(7, 9); | |
CASE_SLOT(8, 9); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 10: | |
switch (slot) { | |
CASE_SLOT(0, 10); | |
CASE_SLOT(1, 10); | |
CASE_SLOT(2, 10); | |
CASE_SLOT(3, 10); | |
CASE_SLOT(4, 10); | |
CASE_SLOT(5, 10); | |
CASE_SLOT(6, 10); | |
CASE_SLOT(7, 10); | |
CASE_SLOT(8, 10); | |
CASE_SLOT(9, 10); | |
CASE_SLOT_DEFAULT(); | |
} | |
#endif __TBB_VARIADIC_MAX | |
default: | |
throw std::runtime_error("unsupported nary"); | |
} | |
throw std::runtime_error("unsupported nary"); | |
#undef CASE_SLOT | |
#undef CONNECT_FNODE_TO_PORT | |
} | |
template <typename T, typename M> | |
void merge_node<T, M>::remove_edge(merge_node<T, M> &src, | |
merge_node<T, M> &dest, size_t slot) { | |
using namespace tbb::flow; | |
size_t dest_nary = dest.get_nary(); | |
assert(slot < dest_nary); | |
assert(src.fnode); | |
assert(dest.fnode); | |
assert(dest_nary > 1 && dest.jnode || | |
dest_nary <= 1 && dest.jnode == nullptr); | |
switch (dest_nary) { | |
case 1: | |
tbb::flow::remove_edge( | |
output_port<0>( | |
*reinterpret_cast<multifunction_node<T, tuple<T>> *>(src.fnode)), | |
*reinterpret_cast<multifunction_node<T, tuple<T>> *>(dest.fnode)); | |
return; | |
#define CONNECT_FNODE_TO_PORT(K, N) \ | |
tbb::flow::remove_edge( \ | |
output_port<0>(*reinterpret_cast<multifunction_node< \ | |
typename RepeatTuple<N, T>::type, tuple<T>> *>(src.fnode)), \ | |
input_port<K, N>(dest)); | |
#define CASE_SLOT(K, N) \ | |
case K: \ | |
CONNECT_FNODE_TO_PORT(K, N); \ | |
return; | |
#define CASE_SLOT_DEFAULT() \ | |
default: \ | |
throw std::runtime_error("unsupported nary"); | |
case 2: | |
switch (slot) { | |
CASE_SLOT(0, 2); | |
CASE_SLOT(1, 2); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 3: | |
switch (slot) { | |
CASE_SLOT(0, 3); | |
CASE_SLOT(1, 3); | |
CASE_SLOT(2, 3); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 4: | |
switch (slot) { | |
CASE_SLOT(0, 4); | |
CASE_SLOT(1, 4); | |
CASE_SLOT(2, 4); | |
CASE_SLOT(3, 4); | |
CASE_SLOT_DEFAULT(); | |
} | |
#if __TBB_VARIADIC_MAX >= 10 | |
case 5: | |
switch (slot) { | |
CASE_SLOT(0, 5); | |
CASE_SLOT(1, 5); | |
CASE_SLOT(2, 5); | |
CASE_SLOT(3, 5); | |
CASE_SLOT(4, 5); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 6: | |
switch (slot) { | |
CASE_SLOT(0, 6); | |
CASE_SLOT(1, 6); | |
CASE_SLOT(2, 6); | |
CASE_SLOT(3, 6); | |
CASE_SLOT(4, 6); | |
CASE_SLOT(5, 6); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 7: | |
switch (slot) { | |
CASE_SLOT(0, 7); | |
CASE_SLOT(1, 7); | |
CASE_SLOT(2, 7); | |
CASE_SLOT(3, 7); | |
CASE_SLOT(4, 7); | |
CASE_SLOT(5, 7); | |
CASE_SLOT(6, 7); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 8: | |
switch (slot) { | |
CASE_SLOT(0, 8); | |
CASE_SLOT(1, 8); | |
CASE_SLOT(2, 8); | |
CASE_SLOT(3, 8); | |
CASE_SLOT(4, 8); | |
CASE_SLOT(5, 8); | |
CASE_SLOT(6, 8); | |
CASE_SLOT(7, 8); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 9: | |
switch (slot) { | |
CASE_SLOT(0, 9); | |
CASE_SLOT(1, 9); | |
CASE_SLOT(2, 9); | |
CASE_SLOT(3, 9); | |
CASE_SLOT(4, 9); | |
CASE_SLOT(5, 9); | |
CASE_SLOT(6, 9); | |
CASE_SLOT(7, 9); | |
CASE_SLOT(8, 9); | |
CASE_SLOT_DEFAULT(); | |
} | |
case 10: | |
switch (slot) { | |
CASE_SLOT(0, 10); | |
CASE_SLOT(1, 10); | |
CASE_SLOT(2, 10); | |
CASE_SLOT(3, 10); | |
CASE_SLOT(4, 10); | |
CASE_SLOT(5, 10); | |
CASE_SLOT(6, 10); | |
CASE_SLOT(7, 10); | |
CASE_SLOT(8, 10); | |
CASE_SLOT(9, 10); | |
CASE_SLOT_DEFAULT(); | |
} | |
#endif __TBB_VARIADIC_MAX | |
default: | |
throw std::runtime_error("unsupported nary"); | |
} | |
throw std::runtime_error("unsupported nary"); | |
#undef CASE_SLOT | |
#undef CONNECT_FNODE_TO_PORT | |
} | |
template <typename T, typename M> | |
void *merge_node<T, M>::create_fnode(tbb::flow::graph &g, size_t nary, | |
const func_t &func) { | |
using namespace tbb::flow; | |
assert(nary >= 0); | |
#define CASE_NARY(N) \ | |
case N: { \ | |
typedef input_ports_tuple<N>::type TR; \ | |
multifunction_node<TR, tuple<T>> *fnode = \ | |
new multifunction_node<TR, tuple<T>>( \ | |
g, 1, \ | |
[=](const TR &t, \ | |
multifunction_node<TR, tuple<T>>::output_ports_type &op) { \ | |
M m; \ | |
const auto &r = func(m(t)); \ | |
if (std::get<0>(r)) \ | |
std::get<0>(op).try_put(std::get<1>(r)); \ | |
}); \ | |
return fnode; \ | |
} | |
switch (nary) { | |
CASE_NARY(1); | |
CASE_NARY(2); | |
CASE_NARY(3); | |
CASE_NARY(4); | |
CASE_NARY(5); | |
CASE_NARY(6); | |
CASE_NARY(7); | |
CASE_NARY(8); | |
CASE_NARY(9); | |
CASE_NARY(10); | |
default: | |
throw std::runtime_error("unsupported nary"); | |
} | |
#undef CASE_NARY | |
} | |
template <typename T, typename M> | |
void *merge_node<T, M>::create_jnode_with_edge(tbb::flow::graph &g, size_t nary, | |
void *ptr) { | |
assert(nary > 1); | |
#define CASE_NARY(N) \ | |
case N: \ | |
return ::create_jnode_with_edge<input_ports_tuple<N>::type, T>(g, ptr); | |
switch (nary) { | |
CASE_NARY(2); | |
CASE_NARY(3); | |
CASE_NARY(4); | |
CASE_NARY(5); | |
#if __TBB_VARIADIC_MAX >= 10 | |
CASE_NARY(6); | |
CASE_NARY(7); | |
CASE_NARY(8); | |
CASE_NARY(9); | |
CASE_NARY(10); | |
#endif // __TBB_VARIADIC_MAX | |
default: | |
throw std::runtime_error("unsupported nary"); | |
} | |
#undef CASE_NARY | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment