Last active: April 17, 2017, 17:49
Save wolfv/890b6597bf70476fb0ebbda1e43c06be to your computer and use it in GitHub Desktop.
Reducers benchmark vs. NumPy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <numpy/arrayobject.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xstridedview.hpp"
#include "xtensor/xexpression.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
// Standard headers are included after the NumPy/pybind11 headers because
// Python.h (pulled in by them) must be included before any standard header.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <vector>
namespace py = pybind11; | |
template<class E1> | |
auto logsumexp1 (const E1& e1) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::amax (e1)(); | |
return std::move (max + xt::log (xt::sum (xt::exp (e1-max)))); | |
} | |
/// Naive hand-written sum over axis 1 of a 2-D expression, used as a
/// baseline against xt::sum in the benchmark.
///
/// `CT` is the closure type of the wrapped expression (for example a const
/// reference to a pyarray). The expression must be 2-D. Elements are read
/// directly through `data()` using `strides()`, so strides are assumed to be
/// element counts and non-negative (TODO confirm for the xtensor version in use).
template <class CT>
class stupid_reduce
{
public:
    using xtype = std::decay_t<CT>;

    /// @throws std::invalid_argument if `e` is not 2-D.
    explicit stupid_reduce(CT&& e) : m_e(e)
    {
        // shape()[1] / strides()[1] below are only valid for rank-2 input.
        if (m_e.dimension() != 2)
            throw std::invalid_argument("stupid_reduce: expected a 2-D expression");
        m_row_stride = m_e.strides()[0];
        m_col_stride = m_e.strides()[1];
        m_shape = {m_e.shape()[0]};
    }

    /// @return a 1-D container of length shape()[0] holding the row sums.
    xtype reduce() const
    {
        xtype result(m_shape, 0);
        const std::size_t n_cols = m_e.shape()[1];
        for (std::size_t i = 0; i < m_shape[0]; ++i)
        {
            // Row i begins at i * row_stride in the flat buffer, not at i.
            std::size_t idx = i * m_row_stride;
            for (std::size_t j = 0; j < n_cols; ++j)
            {
                result.data()[i] += m_e.data()[idx];
                idx += m_col_stride;
            }
        }
        return result;
    }

private:
    CT m_e;
    std::size_t m_row_stride;
    std::size_t m_col_stride;
    std::vector<std::size_t> m_shape;
};
template<class E1, class X> | |
auto logsumexp2 (const E1& e1, X const& axes) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto&& max = xt::eval(xt::amax(e1, axes)); | |
auto sv = xt::slice_vector(max); | |
for (int i = 0; i < e1.dimension(); i++) | |
{ | |
if (std::find (axes.begin(), axes.end(), i) != axes.end()) | |
sv.push_back(xt::newaxis()); | |
else | |
sv.push_back(e1.shape()[i]); | |
} | |
auto max2 = xt::dynamic_view(max, sv); | |
return xt::pyarray<value_type>(max2 + xt::log(xt::sum(xt::exp(e1 - max2), axes))); | |
} | |
template<class value_type> | |
auto normalize (xt::pyarray<value_type> const& e1) { | |
auto shape = std::vector<size_t>{e1.shape().size() - 1}; | |
auto ls = logsumexp2(e1, shape); | |
auto sv = xt::slice_vector(ls); | |
for (int i = 0; i < e1.dimension() - 1; i++) | |
sv.push_back (xt::all()); | |
sv.push_back (xt::newaxis()); | |
auto ls2 = xt::dynamic_view (ls, sv); | |
return xt::pyarray<value_type>(e1 - ls2); | |
} | |
PYBIND11_PLUGIN (logsumexp) {
    // The NumPy C API has to be initialised before any pyarray is built;
    // on failure, report an ImportError to Python and abort module init.
    const int numpy_status = _import_array();
    if (numpy_status < 0) {
        PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
        return nullptr;
    }

    py::module mod("logsumexp", "pybind11 example plugin");

    // logsumexp(x): reduction over every element.
    mod.def("logsumexp", [](const xt::pyarray<double>& arr) {
        return xt::pyarray<double>(xt::eval (logsumexp1 (arr)));
    });

    // Thin wrappers around xtensor reducers / ufuncs for benchmarking.
    mod.def("amax", [](const xt::pyarray<double>& arr, const std::vector<size_t>& axes) {
        return xt::pyarray<double>(xt::amax(arr, axes));
    });
    mod.def("log", [](const xt::pyarray<double>& arr) {
        return xt::pyarray<double>(xt::log(arr));
    });
    mod.def("sum", [](const xt::pyarray<double>& arr, const std::vector<size_t>& axes) {
        return xt::pyarray<double>(xt::sum(arr, axes));
    });

    // Hand-written axis-1 sum baseline.
    mod.def("stupid_sum", [](const xt::pyarray<double>& arr) {
        using closure_type = xt::xclosure_t<decltype(arr)>;
        return stupid_reduce<closure_type>(arr).reduce();
    });

    // logsumexp(x, axes): second overload, reducing over the given axes.
    mod.def("logsumexp", [](const xt::pyarray<double>& arr, const std::vector<size_t>& axes) {
        return logsumexp2 (arr, axes);
    });
    mod.def("normalize", [](const xt::pyarray<double>& arr) {
        return normalize (arr);
    });

    return mod.ptr();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark xtensor reducers (via the `logsumexp` extension) against NumPy/SciPy."""
import numpy as np
from logsumexp import logsumexp, normalize
import logsumexp as lse
from timeit import timeit

# logsumexp lives in scipy.special since SciPy 0.19 and was removed from
# scipy.misc in SciPy 1.3; only fall back to scipy.misc on very old SciPy.
try:
    from scipy.special import logsumexp as scipy_logsumexp
except ImportError:
    from scipy.misc import logsumexp as scipy_logsumexp

u = np.ones((10, 100000))

print(lse.amax(u, (1,)))
print(lse.sum(u, (1,)))
print(lse.stupid_sum(u))

print("Scipy  : logsumexp", timeit('scipy_logsumexp(u, (1,))', 'from __main__ import scipy_logsumexp, u', number=10))
print("Xtensor: logsumexp", timeit('logsumexp(u, (1,))', 'from __main__ import logsumexp, u', number=10))
print("Xtensor: amax ", timeit('lse.amax(u, (1,))', 'from __main__ import lse, u', number=10))
print("Xtensor: sum ", timeit('lse.sum(u, (1,))', 'from __main__ import lse, u', number=10))
print("Numpy : amax ", timeit('np.amax(u, (1,))', 'from __main__ import np, u', number=10))
print("Numpy : sum ", timeit('np.sum(u, (1,))', 'from __main__ import np, u', number=10))
print("Xtensor: ssum ", timeit('lse.stupid_sum(u)', 'from __main__ import lse, u', number=10))
print("Numpy : log ", timeit('np.log(u)', 'from __main__ import np, u', number=10))
print("Xtensor: log ", timeit('lse.log(u)', 'from __main__ import lse, u', number=10))

# Correctness checks. The all-ones benchmark input cannot reveal indexing or
# broadcasting mistakes (every element is equal), so compare on seeded random data.
r = np.random.RandomState(0).rand(10, 1000)
print("Check  : amax      ", np.allclose(lse.amax(r, (1,)), np.amax(r, (1,))))
print("Check  : sum       ", np.allclose(lse.sum(r, (1,)), np.sum(r, (1,))))
print("Check  : stupid_sum", np.allclose(lse.stupid_sum(r), np.sum(r, (1,))))
print("Check  : logsumexp ", np.allclose(logsumexp(r, (1,)), scipy_logsumexp(r, (1,))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.