Skip to content

Instantly share code, notes, and snippets.

@grafi-tt
Created April 24, 2019 13:11
Show Gist options
  • Save grafi-tt/cb90b7cbfee8dffb34c62be46df1f48e to your computer and use it in GitHub Desktop.
Save grafi-tt/cb90b7cbfee8dffb34c62be46df1f48e to your computer and use it in GitHub Desktop.
diff --git a/chainerx_cc/chainerx/native/reduce.h b/chainerx_cc/chainerx/native/reduce.h
index 2b61319de..73db5bad2 100644
--- a/chainerx_cc/chainerx/native/reduce.h
+++ b/chainerx_cc/chainerx/native/reduce.h
@@ -11,19 +11,83 @@ namespace chainerx {
namespace native {
namespace reduce_detail {
+constexpr int64_t ExpandLen = 8;  // Number of elements handled by one unrolled ExpandedPairwiseReduction call in PairwiseReduction.
+constexpr int64_t SerialLen = 16;  // PairwiseReduction checks `i_reduce` against the mask `SerialLen * ExpandLen - 1` to decide when to merge.
+
+// Fully unrolled pairwise (tree-shaped) reduction of `n` consecutive input elements.
+//
+// `n` must be a power of two: the primary template reduces two halves of `n / 2` elements each and
+// folds the second half into the first; the `n == 1` specialization below reads a single element.
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T, int64_t n>
+struct ExpandedPairwiseReduction {
+    static_assert(n >= 2 && (n & (n - 1)) == 0, "n must be a power of two");
+
+    // Reduces the next `n` elements of `in` starting at `it_in`.
+    // Both `it_in` and `i_reduce` are advanced by `n`.
+    template <typename InArray>
+    static T run(InArray& in, IndexIterator<InNdim>& it_in, int64_t& i_reduce, ReductionImpl& impl) {
+        using Half = ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, n / 2>;
+        T accum = Half::run(in, it_in, i_reduce, impl);
+        impl.Reduce(Half::run(in, it_in, i_reduce, impl), accum);
+        return accum;
+    }
+};
+
+// Base case: maps a single input element and advances `it_in` and `i_reduce` by one.
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T>
+struct ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, 1> {
+    template <typename InArray>
+    static T run(InArray& in, IndexIterator<InNdim>& it_in, int64_t& i_reduce, ReductionImpl& impl) {
+        T mapped = impl.MapIn(native_internal::StorageToDataType<const In>(in[it_in]), i_reduce);
+        ++it_in;
+        ++i_reduce;
+        return mapped;
+    }
+};
+
+// Reduces `reduce_len` input elements, starting at `it_in`, and returns the accumulated value.
+//
+// Elements are consumed in unrolled chunks of `ExpandLen`. Chunk results are accumulated serially
+// until `SerialLen * ExpandLen` elements (one "leaf") have been consumed; at every leaf boundary the
+// running value is merged into `tree_accum`, which works like a binary counter: the carry loop below
+// stores the merged value in slot `i >= 1`, which then covers 2^(i-1) leaves (slot 0 is only ever read
+// and reset, so it stays at Identity()). The remaining `reduce_len % ExpandLen` elements are reduced
+// one by one, and finally every slot of `tree_accum` is folded into the result.
+//
+// `in` is the input array, indexed by `it_in`. `tree_accum` is caller-provided scratch space: every
+// entry must equal `impl.Identity()` on entry, it must hold at least `bits_of_index(reduce_len)`
+// entries, and every entry is reset to `impl.Identity()` before returning so it can be reused.
+//
+// NOTE(review): operands are combined in a different order than a plain left-to-right scan (earlier
+// partial results are passed as `next` into a later `accum`). Confirm every ReductionImpl used with
+// this kernel gives the same answer regardless of that order.
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T, typename InArray>
+T PairwiseReduction(InArray& in, IndexIterator<InNdim>& it_in, int64_t reduce_len, std::vector<T>& tree_accum, ReductionImpl& impl) {
+    static_assert(ExpandLen > 0 && (ExpandLen & (ExpandLen - 1)) == 0, "ExpandLen must be a power of two");
+    static_assert(SerialLen > 0 && (SerialLen & (SerialLen - 1)) == 0, "SerialLen must be a power of two");
+    constexpr int64_t kLeafLen = SerialLen * ExpandLen;
+
+    int64_t i_reduce = 0;
+    T accum = impl.Identity();
+
+    // Largest multiple of ExpandLen that does not exceed reduce_len.
+    // The parentheses matter: `a < b & c` would parse as `(a < b) & c`.
+    const int64_t unrolled_len = reduce_len & -ExpandLen;
+    while (i_reduce < unrolled_len) {
+        // A leaf has just been completed (never true at the very start, where nothing has been accumulated yet).
+        // `==` binds tighter than `&`, hence the explicit parentheses around the mask test.
+        if (i_reduce != 0 && (i_reduce & (kLeafLen - 1)) == 0) {
+            int i = 0;
+            int64_t i_reduce_tmp = i_reduce;
+            // Propagate the carry: fold every slot that is passed over into `accum` and clear it.
+            do {
+                impl.Reduce(tree_accum[i], accum);
+                tree_accum[i] = impl.Identity();
+                ++i;
+                i_reduce_tmp >>= 1;
+            } while ((i_reduce_tmp & (kLeafLen - 1)) == 0);
+            tree_accum[i] = accum;
+            accum = impl.Identity();
+        }
+        impl.Reduce(ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, ExpandLen>::run(in, it_in, i_reduce, impl), accum);
+    }
+
+    // Tail: fewer than ExpandLen elements are left.
+    while (i_reduce < reduce_len) {
+        impl.Reduce(impl.MapIn(native_internal::StorageToDataType<const In>(in[it_in]), i_reduce), accum);
+        ++it_in;
+        ++i_reduce;
+    }
+
+    // Fold the partial results kept in the tree and leave the scratch space clean for the next call.
+    for (T& leaf_accum : tree_accum) {
+        impl.Reduce(leaf_accum, accum);
+        leaf_accum = impl.Identity();
+    }
+    return accum;
+}
+
+// Returns the number of bits needed to represent the largest index `n - 1` of a sequence of `n`
+// elements, i.e. the position of the highest set bit of `n - 1` plus one. Returns 0 for `n <= 1`.
+inline int bits_of_index(int64_t n) {
+    if (n <= 0) return 0;
+    --n;
+    int64_t t;
+    int bits = 0;
+    // Binary search for the highest set bit, halving the shift width at each step.
+    if ((t = n >> 32)) bits += 32, n = t;
+    if ((t = n >> 16)) bits += 16, n = t;
+    if ((t = n >> 8)) bits += 8, n = t;
+    if ((t = n >> 4)) bits += 4, n = t;
+    if ((t = n >> 2)) bits += 2, n = t;
+    if ((t = n >> 1)) bits += 1, n = t;
+    bits += static_cast<int>(n);  // n is 0 or 1
+    return bits;
+}

template <typename In, typename Out, typename ReductionImpl, int8_t InNdim = kDynamicNdim, int8_t OutNdim = kDynamicNdim>
void ReductionKernel(ReductionKernelArg<In, Out, InNdim, OutNdim> arg, ReductionImpl&& impl) {
auto it_in = arg.in_indexer.It(0, arg.out_indexer.total_size());
+    using Accum = decltype(impl.Identity());
+    const auto out_total_size = arg.out_indexer.total_size();
+    if (out_total_size == 0) {
+        // No output element to compute; also avoids dividing by zero below.
+        return;
+    }
+    // Number of input elements reduced into each output element.
+    // NOTE(review): assumes the input size is an exact multiple of the output size; the removed loop
+    // stopped when `it_in` was exhausted instead of counting elements. Confirm against the callers.
+    const int64_t reduce_len = arg.in_indexer.total_size() / out_total_size;
+    // Scratch space shared by all output elements; PairwiseReduction resets it before returning.
+    // NOTE(review): std::vector requires <vector>, which must be included at the top of this header
+    // (outside the lines shown here).
+    std::vector<Accum> tree_accum(static_cast<size_t>(bits_of_index(reduce_len)), impl.Identity());
// Iterate over output dimensions
for (auto it_out = arg.out_indexer.It(0); it_out; ++it_out) {
- auto accum = impl.Identity();
-
- int64_t i_reduce{0};
- for (it_in.Restart(it_out.raw_index()); it_in; ++it_in, ++i_reduce) {
- impl.Reduce(impl.MapIn(native_internal::StorageToDataType<const In>(arg.in[it_in]), i_reduce), accum);
- }
-
+        it_in.Restart(it_out.raw_index());
+        Accum accum = PairwiseReduction<In, ReductionImpl, InNdim, Accum>(arg.in, it_in, reduce_len, tree_accum, impl);
arg.out[it_out] = native_internal::DataToStorageType<Out>(impl.MapOut(accum));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment