-
-
Save nbecker/af2011043a29e1635adfaad73c0e4e9e to your computer and use it in GitHub Desktop.
// #include "pybind11/pybind11.h" | |
// #include "pybind11/stl.h" | |
#include "xtensor/xarray.hpp" | |
#include "xtensor/xtensor.hpp" | |
#include "xtensor/xcontainer.hpp" | |
#include "xtensor/xbroadcast.hpp" | |
//#include "xtensor/xbuilder.hpp" | |
#include "xtensor/xview.hpp" | |
#include "xtensor/xeval.hpp" | |
#include "xtensor/xstridedview.hpp" | |
#include "jlcxx/jlcxx.hpp" | |
#include "xtensor-julia/jltensor.hpp" // Import the jltensor container definition | |
#include "xtensor-julia/jlarray.hpp" // Import the jltensor container definition | |
#include <algorithm> // ? | |
//namespace py = pybind11; | |
template<class E1> | |
auto logsumexp1 (E1 const& e1) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::amax (e1)(); | |
return std::move (max + xt::log (xt::sum (xt::exp (e1-max)))); | |
} | |
template<class E1, class X> | |
auto logsumexp2 (E1 const& e1, X const& axes) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::eval(xt::amax(e1, axes)); | |
auto sv = xt::slice_vector(max); | |
for (int i = 0; i < e1.dimension(); i++) | |
{ | |
if (std::find (axes.begin(), axes.end(), i) != axes.end()) | |
sv.push_back(xt::newaxis()); | |
else | |
sv.push_back(xt::all()); | |
} | |
auto max2 = xt::dynamic_view(max, sv); | |
return xt::jlarray<value_type>(max + xt::log(xt::sum(xt::exp(e1 - max2), axes))); | |
} | |
// template<class value_type> | |
// auto normalize (xt::pyarray<value_type> const& e1) { | |
// auto shape = std::vector<size_t>{e1.shape().size()-1}; | |
// auto ls = logsumexp2 (e1, shape); | |
// auto sv = xt::slice_vector(ls); | |
// for (int i = 0; i < e1.dimension()-1; i++) | |
// sv.push_back (xt::all()); | |
// sv.push_back (xt::newaxis()); | |
// auto ls2 = xt::dynamic_view (ls, sv); | |
// return xt::pyarray<value_type> ((e1 - ls2)); | |
// //return ls; | |
// } | |
JULIA_CPP_MODULE_BEGIN(registry) | |
jlcxx::Module& m = registry.create_module("logsumexpnb"); | |
m.method("logsumexp", [](xt::jlarray<double> x) { | |
return xt::jlarray<double> (logsumexp1 (x)); | |
}); | |
m.method("logsumexp", [](xt::jlarray<double> x, xt::jltensor<size_t,1> ax) { | |
//return xt::pyarray<double> ( (logsumexp2 (x, ax))); | |
return logsumexp2 (x, ax); | |
}); | |
// m.method("normalize", [](xt::pyarray<double>const& x) { | |
// return normalize (x); | |
// }); | |
JULIA_CPP_MODULE_END | |
using CxxWrap
using Xtensor

# Load the compiled C++ module built from logsumexp.cc.
wrap_modules("./liblogsumexp")
using logsumexpnb

import StatsFuns.logsumexp

# Pure-Julia reference implementation of log-sum-exp reduced over
# `axes`, for comparison against the C++ wrapper.
function logsumexp(u::AbstractArray, axes)
    m = maximum(u, axes)
    return m .+ log.(sum(exp.(u .- m), axes))
end

u = ones(4)
v = logsumexpnb.logsumexp(u)
w = logsumexp(u)

x = ones(4,4)
# Reduce over axis 1 (zero-based on the C++ side).  This must be called
# with the 2-d array `x`: calling it with the 1-d `u` passes one
# dimension fewer than the axes argument expects and hangs/misbehaves.
y = logsumexpnb.logsumexp(x, Array{UInt64}([1]))
In `return xt::jlarray<double>(xt::eval(logsumexp1(x)));` you should not need the `eval`.
A `jlarray` can be initialized from any expression, so it can be initialized directly from `logsumexp1(x)`.
By using `eval`, you are creating a temporary `xarray`.
Thanks for the suggestions. I tried updating logsumexp.cc accordingly, now I have a new behavior.
The test:
u = ones(4)
v = logsumexpnb.logsumexp(u)
Now fails to terminate (runs forever)
@nbecker, I just ran your original code.
In `test_logsumexp.jl` you call the function as `y = logsumexpnb.logsumexp(u, Array{UInt64}([1]))` — with `u` instead of `x`.
I.e. there is one dimension fewer than expected by the algorithm.
When calling it with `x` it works for me.
We should probably throw an appropriate runtime error or a warning, though.
However, we also fixed some memory issues in the reducers with the latest release.
Cheers,
Wolf
I am going to close this, and refer to the following issue for further reference: xtensor-stack/xtensor#629
So the only place where passing arguments by value is a requirement will be the functions passed to `m.method(...)`,
where the actual arguments used for the call are of type `jl_array_t*` and are converted to `jlarray`.