Skip to content

Instantly share code, notes, and snippets.

@nbecker
Last active April 17, 2017 12:49
Show Gist options
  • Save nbecker/7f13da1a108e956fdcea7915b29085f2 to your computer and use it in GitHub Desktop.
Save nbecker/7f13da1a108e956fdcea7915b29085f2 to your computer and use it in GitHub Desktop.
logsumexp.cc
#include <numpy/arrayobject.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
//#include "xtensor/xbuilder.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xstridedview.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
// Standard headers are kept after the numpy/pybind11 headers: Python.h is
// meant to be included before any standard header.
#include <algorithm>    // std::find
#include <cmath>        // std::isfinite
#include <cstddef>      // std::size_t
#include <stdexcept>    // std::invalid_argument
#include <type_traits>  // std::decay_t
#include <vector>       // std::vector
namespace py = pybind11;
/// Log-sum-exp over every element of an xtensor expression.
///
/// Computes log(sum(exp(e1))) with the max-shift trick,
///   max + log(sum(exp(e1 - max))),
/// so large inputs do not overflow exp().
///
/// @param e1  any xtensor expression; it is only read.
/// @return    an evaluated xtensor container holding the (0-d) result.
///
/// The expression is evaluated before returning. xtensor wraps an lvalue
/// scalar such as the local `max` by reference inside a lazy expression.
/// Returning the unevaluated expression would leave it referring to a
/// destroyed local once the caller evaluates it.
template<class E1>
auto logsumexp1 (E1 const& e1) {
  using value_type = typename std::decay_t<E1>::value_type;
  value_type max = xt::amax (e1)();
  // Matches scipy.special.logsumexp: a non-finite max (+inf, -inf, NaN)
  // would make e1 - max NaN. Shifting by 0 instead lets +inf / all -inf
  // inputs give +inf / -inf; NaN inputs still give NaN.
  if (!std::isfinite (max))
    max = value_type(0);
  return xt::eval (max + xt::log (xt::sum (xt::exp (e1 - max))));
}
/// Log-sum-exp of an xtensor expression reduced over selected axes.
///
/// The reduced axes are removed from the result, as in
/// numpy/scipy `logsumexp(u, axis=axes)`. For example, a (2, 4) input
/// reduced over {1} yields shape (2,).
///
/// @param e1    any xtensor expression; it is only read.
/// @param axes  container of axis indices to reduce over (searched with
///              std::find, so it must provide begin()/end()).
/// @return      an xt::pyarray<value_type> holding the result.
template<class E1, class X>
auto logsumexp2 (const E1& e1, X const& axes) {
  using value_type = typename std::decay_t<E1>::value_type;
  // Per-slice maxima, with the reduced axes removed.
  auto max = xt::eval(xt::amax(e1, axes));
  // Re-insert the reduced axes as length-1 dimensions ("keepdims") so the
  // shift broadcasts against e1. A newaxis does not consume a dimension of
  // `max`; xt::all() keeps each surviving dimension whole. An integer here
  // would be an index and select (and drop) a single position instead.
  auto sv = xt::slice_vector(max);
  for (std::size_t i = 0; i < e1.dimension(); ++i)
  {
    if (std::find (axes.begin(), axes.end(), i) != axes.end())
      sv.push_back(xt::newaxis());
    else
      sv.push_back(xt::all());
  }
  auto max2 = xt::eval (xt::dynamic_view(max, sv));
  // Shift by the keepdims max2 inside exp(), but add back the reduced `max`.
  // Adding max2 would broadcast its length-1 axes against the reduced sum
  // and produce a result of the wrong shape.
  return xt::pyarray<value_type>(max + xt::log(xt::sum(xt::exp(e1 - max2), axes)));
}
/// Subtract the log-sum-exp along the last axis from every element:
///   e1 - logsumexp(e1, {last axis})[..., newaxis]
///
/// @param e1  array with at least one dimension.
/// @return    e1 minus the last-axis log-sum-exp. The reduced value is
///            broadcast back along the last axis via a trailing newaxis.
/// @throws std::invalid_argument if e1 is 0-dimensional (no last axis).
template<class value_type>
auto normalize (xt::pyarray<value_type> const& e1) {
  const std::size_t ndim = e1.dimension();
  // Without this guard, ndim - 1 wraps around to SIZE_MAX for a 0-d input.
  if (ndim == 0)
    throw std::invalid_argument("normalize: input must have at least one dimension");
  // Reduce over the last axis only (one-element axis list).
  const auto axes = std::vector<std::size_t>{ndim - 1};
  auto ls = logsumexp2 (e1, axes);
  // Keep the leading axes whole and append a length-1 trailing axis so that
  // ls broadcasts against e1 along the reduced (last) axis.
  auto sv = xt::slice_vector(ls);
  for (std::size_t i = 0; i + 1 < ndim; ++i)
    sv.push_back (xt::all());
  sv.push_back (xt::newaxis());
  auto ls2 = xt::dynamic_view (ls, sv);
  return xt::pyarray<value_type> (e1 - ls2);
}
/// Python module `logsumexp`. It exposes two `logsumexp` overloads (all
/// elements, or selected axes) and `normalize`, for float64 arrays.
PYBIND11_PLUGIN (logsumexp) {
  // Initialise numpy's C API before any xt::pyarray is created. On failure,
  // raise ImportError in Python and abort module creation.
  if (_import_array() < 0) {
    PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
    return nullptr;
  }

  py::module mod("logsumexp", "pybind11 example plugin");

  // logsumexp(x): reduce over every element of x.
  mod.def("logsumexp", [](xt::pyarray<double> const& arr) {
    return xt::pyarray<double> (xt::eval (logsumexp1 (arr)));
  });

  // logsumexp(x, axes): reduce over the listed axes only. The axes arrive
  // as std::vector<size_t> via the pybind11/stl.h casters.
  mod.def("logsumexp", [](xt::pyarray<double> const& arr,
                          std::vector<size_t> const& reduce_axes) {
    return logsumexp2 (arr, reduce_axes);
  });

  // normalize(x): x minus its log-sum-exp over the last axis.
  mod.def("normalize", [](xt::pyarray<double> const& arr) {
    return normalize (arr);
  });

  return mod.ptr();
}
import numpy as np
from xtensor_test.logsumexp import logsumexp, normalize
def logsumexp_py1 (u):
    """Reference log-sum-exp over all elements of ``u`` (pure numpy).

    Uses the max-shift trick ``m + log(sum(exp(u - m)))`` to avoid overflow.
    As in ``scipy.special.logsumexp``, a non-finite maximum is replaced by 0.
    All-``-inf`` input then gives ``-inf`` and any ``+inf`` input gives
    ``+inf``, instead of NaN from ``inf - inf``.
    """
    m = np.max (u)
    if not np.isfinite (m):
        m = 0
    return m + np.log (np.sum (np.exp (u - m)))
def logsumexp_py2 (u, axes):
    """Reference log-sum-exp of ``u`` reduced over ``axes`` (pure numpy).

    ``axes`` is an int or an iterable of ints; negative values are allowed.
    The reduced axes are removed from the result, matching
    ``scipy.special.logsumexp(u, axis=axes)``.
    """
    axes = tuple (axes) if np.iterable (axes) else (axes,)
    # Keep the reduced axes (length 1) so the shift broadcasts against u.
    m2 = np.max (u, axis=axes, keepdims=True)
    # A non-finite maximum would make u - m2 NaN; shift by 0 there instead.
    m2 = np.where (np.isfinite (m2), m2, 0)
    # Add back the max with the reduced axes dropped, so it has the same shape
    # as the sum. Adding m2 itself would broadcast to a larger, wrong shape.
    return np.squeeze (m2, axis=axes) + np.log (np.sum (np.exp (u - m2), axis=axes))
def logsumexp_py (u, axes=None):
    """Reference log-sum-exp, dispatching on ``axes``.

    ``axes=None`` reduces over every element (``logsumexp_py1``). Otherwise
    the reduction is over ``axes`` (``logsumexp_py2``).
    """
    # ``is None`` rather than ``== None``: ``==`` against a numpy array of
    # axes is element-wise, and the truth test would then raise.
    if axes is None:
        return logsumexp_py1 (u)
    return logsumexp_py2 (u, axes)
u = np.ones (4)
v = logsumexp(u)
print (v)
print (logsumexp_py (u, (0,)))
print (logsumexp (u, (0,)))
print ('logsumexp (np.ones ((2,4))):', logsumexp (np.ones ((2,4))))
print (logsumexp_py (np.ones ((2,4))))
print ('logsumexp_py (np.ones ((2,4)), (1,))):', logsumexp_py (np.ones ((2,4)), (1,)))
print ('logsumexp (np.ones ((2,4)), (1,)):', logsumexp (np.ones ((2,4)), (1,)))
#print (np.ones ((2,4)))
print ('norm:', normalize (np.ones ((2,4))))
# scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3;
# scipy.special.logsumexp (available since 0.19) is the supported location.
try:
    from scipy.special import logsumexp as logsumexp2
except ImportError:  # very old SciPy
    from scipy.misc import logsumexp as logsumexp2
print (logsumexp2 (u))
def normalize2 (u):
    """SciPy-based reference for ``normalize``.

    Subtracts the last-axis log-sum-exp from ``u``. The log-sum-exp and the
    input are printed for side-by-side comparison.
    """
    lse = logsumexp2 (u, axis=-1)
    print ('ls:', lse)
    print ('u:', u)
    # Re-append the reduced last axis so lse broadcasts against u.
    return u - lse[..., np.newaxis]
print ('norm2:', normalize2 (np.ones ((2,4))))
w = np.ones ((2,2,4))
print ('logsumexp (np.ones (2,2,4)), (2,)):', logsumexp (w, (2,)))
print ('logsumexp2 (np.ones (2,2,4)), (2,)):', logsumexp2 (w, (2,)))
from timeit import timeit
u = np.ones ((2, 100000))
# globals=globals() lets the timed statements see this module's names even
# when the script is not running as __main__ (e.g. when it is imported).
print (timeit ('logsumexp2 (u, (1,))', globals=globals(), number=10))
print (timeit ('logsumexp (u, (1,))', globals=globals(), number=10))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment