Last active
February 5, 2018 16:56
-
-
Save nbecker/af2011043a29e1635adfaad73c0e4e9e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// #include "pybind11/pybind11.h"
// #include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
//#include "xtensor/xbuilder.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xstridedview.hpp"
#include "jlcxx/jlcxx.hpp"
#include "xtensor-julia/jltensor.hpp" // Import the jltensor container definition
#include "xtensor-julia/jlarray.hpp" // Import the jlarray container definition
#include <algorithm> // std::find
#include <cstddef>
#include <stdexcept>
#include <string>
#include <type_traits>
//namespace py = pybind11; | |
template<class E1> | |
auto logsumexp1 (E1 const& e1) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::amax (e1)(); | |
return std::move (max + xt::log (xt::sum (xt::exp (e1-max)))); | |
} | |
template<class E1, class X> | |
auto logsumexp2 (E1 const& e1, X const& axes) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::eval(xt::amax(e1, axes)); | |
auto sv = xt::slice_vector(max); | |
for (int i = 0; i < e1.dimension(); i++) | |
{ | |
if (std::find (axes.begin(), axes.end(), i) != axes.end()) | |
sv.push_back(xt::newaxis()); | |
else | |
sv.push_back(xt::all()); | |
} | |
auto max2 = xt::dynamic_view(max, sv); | |
return xt::jlarray<value_type>(max + xt::log(xt::sum(xt::exp(e1 - max2), axes))); | |
} | |
// template<class value_type> | |
// auto normalize (xt::pyarray<value_type> const& e1) { | |
// auto shape = std::vector<size_t>{e1.shape().size()-1}; | |
// auto ls = logsumexp2 (e1, shape); | |
// auto sv = xt::slice_vector(ls); | |
// for (int i = 0; i < e1.dimension()-1; i++) | |
// sv.push_back (xt::all()); | |
// sv.push_back (xt::newaxis()); | |
// auto ls2 = xt::dynamic_view (ls, sv); | |
// return xt::pyarray<value_type> ((e1 - ls2)); | |
// //return ls; | |
// } | |
JULIA_CPP_MODULE_BEGIN(registry) | |
jlcxx::Module& m = registry.create_module("logsumexpnb"); | |
m.method("logsumexp", [](xt::jlarray<double> x) { | |
return xt::jlarray<double> (logsumexp1 (x)); | |
}); | |
m.method("logsumexp", [](xt::jlarray<double> x, xt::jltensor<size_t,1> ax) { | |
//return xt::pyarray<double> ( (logsumexp2 (x, ax))); | |
return logsumexp2 (x, ax); | |
}); | |
// m.method("normalize", [](xt::pyarray<double>const& x) { | |
// return normalize (x); | |
// }); | |
JULIA_CPP_MODULE_END | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using CxxWrap | |
using Xtensor | |
wrap_modules("./liblogsumexp") | |
using logsumexpnb | |
import StatsFuns.logsumexp | |
# function logsumexp(A::AbstractArray, Dims) | |
# return mapslices(logsumexp, A, Dims) | |
# end | |
# Numerically stable log(sum(exp(u))) along `axes`: subtract the per-slice
# maximum before exponentiating, then add it back. Because StatsFuns.logsumexp
# is imported by name, this adds an axes-reducing method to that function.
# Uses the positional-dims form `maximum(A, dims)` / `sum(A, dims)`.
function logsumexp(u::AbstractArray, axes)
    peak = maximum(u, axes)
    scaled = exp.(u .- peak)
    return peak .+ log.(sum(scaled, axes))
end
# Smoke test: compare the C++ full reduction with StatsFuns, then exercise
# the axis-reducing C++ overload.
u = ones(4)
v = logsumexpnb.logsumexp(u)
w = logsumexp(u)
x = ones(4,4)
# The C++ side treats axis indices as 0-based, so axis 1 only exists for
# the 2-D `x`; passing the 1-D `u` here names a non-existent axis.
y = logsumexpnb.logsumexp(x,Array{UInt64}([1]))
Thanks for the suggestions. I tried updating logsumexp.cc accordingly, and now I see a new behavior.
The test:
u = ones(4)
v = logsumexpnb.logsumexp(u)
now fails to terminate (it runs forever).
@nbecker, I just ran your original code.
In test_logsumexp.jl you call the function `y = logsumexpnb.logsumexp(u,Array{UInt64}([1]))` with `u` instead of `x`.
That is, the array has one dimension fewer than the algorithm expects.
When calling it with `x`, it works for me.
We should probably throw an appropriate runtime error or a warning, though.
Also, we fixed some memory issues in the reducers in the latest release.
Cheers,
Wolf
I am going to close this, and refer to the following issue for further reference: xtensor-stack/xtensor#629
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
In `return xt::jlarray<double> (xt::eval(logsumexp1(x)));`, you should not need the `eval`. A `jlarray` can be initialized from any expression, so it can be initialized directly from `logsumexp1(x)`. By using `eval`, you are creating a temporary `xarray`.