Last active: December 12, 2015 12:49
-
-
Save jey/4774546 to your computer and use it in GitHub Desktop.
Eigen + NumPy example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This translation unit calls import_array() (in initlibtheta below), so it must
// own NumPy's C-API table rather than only declare it: defining SKIP_NO_IMPORT
// keeps common.hpp from defining NO_IMPORT here.
#define SKIP_NO_IMPORT
#include "common.hpp"
#include <iostream>

namespace libtheta {

// Python: moments_from_samples(weights, samples) -> (mean, covar)
//
// Weighted first and second moments of a sample set.
//   weights: 1-d array-like of length N (converted to a C-contiguous float64 ndarray).
//   samples: 2-d array-like of shape (D, N); column i is sample x_i.
// Returns the tuple (mean, covar), a length-D vector and a D x D matrix:
//   W     = sum_i w_i
//   mean  = sum_i w_i x_i / W
//   covar = sum_i w_i x_i x_i^T / W - mean mean^T
// Errors: returns nullptr (with the Python exception already set) if argument
// parsing/conversion fails; throws std::runtime_error (reported by pywrap as a
// RuntimeError) if the sizes disagree or the weights sum to zero.
PyObjectRef moments_from_samples(const PyObjectRef &self, const PyObjectRef &args)
{
    PyArrayObject *weights_ = NULL;
    PyArrayObject *samples_ = NULL;
    const int parsed = PyArg_ParseTuple(args.borrow(), "O&O&",
                                        const_matrix_converter, &weights_,
                                        const_matrix_converter, &samples_);
    // The converters produce *new* references. Own them in named locals so that
    // (a) they are released on every exit path, including when only the first
    // argument was converted, and (b) they outlive the Eigen maps below. Wrapping
    // them in a temporary PyArrayRef would drop the last reference to a freshly
    // converted array (e.g. from a Python list) at the end of that statement and
    // leave the map pointing at freed memory.
    const PyArrayRef weights_ref = newref(weights_);
    const PyArrayRef samples_ref = newref(samples_);
    // A converter that stored NULL has left a Python exception set.
    if(!parsed || !weights_ref || !samples_ref) {
        return nullptr;
    }

    const auto weights = const_pyvec(weights_ref);
    const auto samples = const_pymat(samples_ref);
    const Matrix::Index dim = samples.rows();
    const Matrix::Index num_samples = samples.cols();
    // Bad user input must surface as a Python exception; ENSURE is a bare assert
    // in non-NDEBUG builds and would abort the interpreter instead.
    if(weights.size() != num_samples) {
        throw std::runtime_error("weights must have one entry per sample (column)");
    }

    Real total_weight = 0.0;
    Vector mean = Vector::Zero(dim);
    Matrix covar_ = Matrix::Zero(dim, dim);
    // Only the lower triangle of covar_ is accumulated; Matrix(covar) below
    // expands it into the full symmetric matrix.
    auto covar = covar_.selfadjointView<Lower>();
    for(Matrix::Index i = 0; i < num_samples; ++i) {
        const Real weight = weights[i];
        const auto sample = samples.col(i);
        mean += weight * sample;
        covar.rankUpdate(sample, weight);   // covar += weight * sample * sample^T
        total_weight += weight;
    }
    // Covers the empty-sample case too; dividing by zero would yield NaN/Inf.
    if(total_weight == 0) {
        throw std::runtime_error("sum of weights is zero");
    }
    mean /= total_weight;
    covar_ /= total_weight;
    covar.rankUpdate(mean, -1);             // covar -= mean * mean^T

    const PyArrayRef o_mean = pyvec(mean);
    const PyArrayRef o_covar = pymat(Matrix(covar));
    return newref(PyTuple_Pack(2, reinterpret_cast<PyObject*>(o_mean.borrow()),
                                  reinterpret_cast<PyObject*>(o_covar.borrow())));
}

// Method table for the libtheta module, terminated by an all-zero sentinel entry.
PyMethodDef python_methods[] = {
    { "moments_from_samples", pywrap<moments_from_samples>, METH_VARARGS, NULL },
    { NULL, NULL, 0, NULL }
};

// Python 2 module initialiser for `import libtheta`.
extern "C"
PyMODINIT_FUNC initlibtheta()
{
    // Fills in NumPy's C-API table for this extension; per the NumPy docs the
    // macro returns early with an ImportError set if NumPy cannot be imported.
    import_array();
    // Py_InitModule3 returns a borrowed reference (NULL with an exception set on
    // failure); the import machinery reports any pending error, so it is unused.
    Py_InitModule3("libtheta", python_methods, "libtheta");
}

} // namespace libtheta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "common.hpp"

namespace libtheta {

namespace {

// Throws std::runtime_error unless `obj` can be viewed in place by an Eigen
// Map<..., Aligned>. The checks, in order:
//   - obj is non-NULL;
//   - obj carries *all* bits of `required_flags`;
//   - its dtype is float64;
//   - its rank is `ndim` (otherwise `rank_message` is thrown);
//   - its data pointer is 16-byte aligned.
void check_mappable(PyArrayObject *obj, int required_flags, int ndim, const char *rank_message)
{
    if(obj == NULL) throw std::runtime_error("expected an ndarray, got NULL");
    // Every requested flag must be present; testing `flags & mask` alone accepted
    // arrays with only one of the bits (e.g. aligned but not C-contiguous).
    if((PyArray_FLAGS(obj) & required_flags) != required_flags) throw std::runtime_error("unmappable ndarray");
    if(PyArray_TYPE(obj) != NPY_DOUBLE) throw std::runtime_error("ndarray is not of type double");
    if(PyArray_NDIM(obj) != ndim) throw std::runtime_error(rank_message);
    // NumPy's ALIGNED flag only promises alignment to sizeof(double). Eigen's
    // `Aligned` maps require 16 bytes (a view like a[1:] of a 16-byte aligned array
    // is only 8-byte aligned). Without this check, Eigen asserts in debug builds
    // and may fault on aligned SIMD loads otherwise.
    if(reinterpret_cast<std::size_t>(PyArray_DATA(obj)) % 16 != 0)
        throw std::runtime_error("ndarray data is not 16-byte aligned");
}

}  // namespace

// Writable zero-copy Eigen view of a 1-d, C-contiguous, writeable float64 ndarray.
// The map aliases the array's buffer: the array must stay alive (e.g. via `ref`)
// for as long as the map is used.
Map<Vector, Aligned> pyvec(const PyArrayRef &ref)
{
    PyArrayObject *obj = ref.borrow();
    check_mappable(obj, NPY_CARRAY, 1, "expected a vector");
    double *data = reinterpret_cast<double*>(PyArray_DATA(obj));
    return Map<Vector, Aligned>(data, PyArray_DIM(obj, 0));
}

// Read-only counterpart of pyvec(); the array need not be writeable.
Map<const Vector, Aligned> const_pyvec(const PyArrayRef &ref)
{
    PyArrayObject *obj = ref.borrow();
    check_mappable(obj, NPY_CARRAY_RO, 1, "expected a vector");
    const double *data = reinterpret_cast<const double*>(PyArray_DATA(obj));
    return Map<const Vector, Aligned>(data, PyArray_DIM(obj, 0));
}

// Writable zero-copy view of a 2-d, C-contiguous, writeable float64 ndarray as a
// row-major Matrix of shape (dim 0, dim 1). The array must outlive the map.
Map<Matrix, Aligned> pymat(const PyArrayRef &ref)
{
    PyArrayObject *obj = ref.borrow();
    check_mappable(obj, NPY_CARRAY, 2, "expected a matrix");
    double *data = reinterpret_cast<double*>(PyArray_DATA(obj));
    return Map<Matrix, Aligned>(data, PyArray_DIM(obj, 0), PyArray_DIM(obj, 1));
}

// Read-only counterpart of pymat(); the array need not be writeable.
Map<const Matrix, Aligned> const_pymat(const PyArrayRef &ref)
{
    PyArrayObject *obj = ref.borrow();
    check_mappable(obj, NPY_CARRAY_RO, 2, "expected a matrix");
    const double *data = reinterpret_cast<const double*>(PyArray_DATA(obj));
    return Map<const Matrix, Aligned>(data, PyArray_DIM(obj, 0), PyArray_DIM(obj, 1));
}

// PyArg_ParseTuple "O&" converter: stores in *dest_ (a PyObject**) a *new*
// reference to a float64 ndarray with 1 or 2 dimensions and NPY_IN_ARRAY
// requirements, copying the input if needed. The caller owns the result.
int const_matrix_converter(PyObject *obj, void *dest_)
{
    PyObject **dest = reinterpret_cast<PyObject**>(dest_);
    *dest = PyArray_FromAny(obj, PyArray_DescrFromType(NPY_DOUBLE), 1, 2, NPY_IN_ARRAY, NULL);
    // Returning 0 makes PyArg_ParseTuple fail and propagate the exception that
    // PyArray_FromAny set; returning 1 unconditionally let a NULL slip through.
    return *dest != NULL;
}

// Like const_matrix_converter, but requests NPY_INOUT_ARRAY (writeable) semantics.
// NOTE(review): if a copy is made, write-back to `obj` presumably relies on NumPy's
// UPDATEIFCOPY handling -- confirm before relying on in-place updates.
int matrix_converter(PyObject *obj, void *dest_)
{
    PyObject **dest = reinterpret_cast<PyObject**>(dest_);
    *dest = PyArray_FromAny(obj, PyArray_DescrFromType(NPY_DOUBLE), 1, 2, NPY_INOUT_ARRAY, NULL);
    return *dest != NULL;
}

} // namespace libtheta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef LIBTHETA__COMMON_H
#define LIBTHETA__COMMON_H

// NumPy C-API plumbing (see the NumPy docs on PY_ARRAY_UNIQUE_SYMBOL / NO_IMPORT):
// NumPy's C API is reached through a table that import_array() fills in.
// PY_ARRAY_UNIQUE_SYMBOL gives that table one shared name, so every source file of
// this extension uses the same table. Exactly one file -- the one that calls
// import_array() and defines SKIP_NO_IMPORT before including this header --
// defines the table; every other file gets NO_IMPORT and only declares it.
#define PY_ARRAY_UNIQUE_SYMBOL libtheta_PyArray_API
#ifndef SKIP_NO_IMPORT
#define NO_IMPORT
#endif

// Python.h must precede the standard headers.
#include <Python.h>
#include <numpy/arrayobject.h>
#include <Eigen/Dense>

#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

// Branch-prediction hint (GCC/Clang builtin): `x` is expected to be false.
#define UNLIKELY(x) __builtin_expect(!!(x), 0)

// ENSURE(x): internal-invariant check that stays active in every build.
//   - With NDEBUG (where assert() is compiled out), a failure throws
//     std::runtime_error("FATAL: <expression text>").
//   - Without NDEBUG, it is assert(x), which aborts the process.
// Because of the abort case, do not use it to validate user input.
#ifdef NDEBUG
#define ENSURE(x) do { if(UNLIKELY(!(x))) throw std::runtime_error("FATAL: " #x); } while(false)
#else
#define ENSURE(x) assert(x)
#endif

namespace libtheta {

typedef double Real;
typedef Eigen::Matrix<Real, Eigen::Dynamic, 1> Vector;
typedef Eigen::Matrix<Real, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
typedef Eigen::Matrix<Real, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> ColMajorMatrix;
// Row-major so that pymat()/const_pymat() can map C-contiguous 2-d arrays in place.
typedef RowMajorMatrix Matrix;

// Kept for existing callers, which use Map, Aligned, Lower, ... unqualified.
using namespace Eigen;

// std::runtime_error whose message is "<prefix>: <strerror(errno)>", or just the
// strerror text when `prefix` is empty. errno is read at construction time.
class ErrnoError : public std::runtime_error
{
public:
    ErrnoError(const std::string &prefix = "") :
        std::runtime_error(make_error_message(prefix, errno))
    {
    }

private:
    static std::string make_error_message(const std::string &prefix, int errnum)
    {
        std::string msg;
        if(!prefix.empty()) msg = prefix + ": ";
        msg.append(strerror(errnum));
        return msg;
    }
};

// Signals that a Python exception is pending. The wrappers below return the
// failure value to Python without overwriting that exception.
class PythonError : public std::exception
{
public:
    // Use when a Python API call has already set the exception.
    PythonError()
    {
        ENSURE(PyErr_Occurred());
    }

    // Sets a new Python exception of `type` with message `msg`.
    PythonError(PyObject *type, const std::string &msg)
    {
        PyErr_SetString(type, msg.c_str());
    }
};

// Owning handle for one reference to a Python object (possibly NULL).
//   - The destructor releases the reference (Py_XDECREF).
//   - Copies acquire an extra reference; moves transfer the reference.
//   - newref(p) adopts a reference the caller already owns.
//   - borrow(p) acquires a new reference to an object owned elsewhere.
// note: return values of const methods are non-const because Python/C API is not const-aware
template <typename PyObjectT>
class PyRef
{
    typedef void (PyRef<PyObjectT>::*dummy_bool_type)() const;
    void dummy_true() const {}

public:
    // Takes ownership of a reference the caller already holds (no incref).
    static PyRef newref(PyObjectT *obj)
    {
        return PyRef(obj);
    }

    // Acquires an additional reference to `obj` (NULL is allowed).
    static PyRef borrow(PyObjectT *obj)
    {
        Py_XINCREF(obj);
        return PyRef(obj);
    }

    PyRef() :
        obj(NULL)
    {
    }

    // Enables `return nullptr;` from functions that return a PyRef.
    PyRef(decltype(nullptr)) :
        obj(NULL)
    {
    }

    PyRef(const PyRef &other) :
        obj(other.obj)
    {
        Py_XINCREF(obj);
    }

    PyRef(PyRef &&other) noexcept :
        obj(NULL)
    {
        std::swap(obj, other.obj);
    }

    ~PyRef()
    {
        Py_XDECREF(obj);
    }

    PyRef& operator= (const PyRef &other)
    {
        // Increment before decrementing so that self-assignment is safe.
        Py_XINCREF(other.obj);
        Py_XDECREF(obj);
        obj = other.obj;
        return *this;
    }

    // Swaps with `other`; our previous reference is released when `other` dies.
    PyRef& operator= (PyRef &&other) noexcept
    {
        std::swap(obj, other.obj);
        return *this;   // was missing: flowing off the end of this function is UB
    }

    // Returns a new reference for the caller to own (e.g. to hand back to Python).
    PyObjectT* newref() const
    {
        Py_XINCREF(obj);
        return obj;
    }

    // Returns the raw pointer without changing its reference count.
    PyObjectT* borrow() const
    {
        return obj;
    }

    PyObjectT& operator* () const
    {
        return *obj;
    }

    PyObjectT* operator-> () const
    {
        return obj;
    }

    // Unchecked pointer cast to another object struct type, as a new owning handle.
    template <typename TargetT>
    PyRef<TargetT> cast() const
    {
        return PyRef<TargetT>::borrow(reinterpret_cast<TargetT*>(obj));
    }

    // uses the "safe bool" idiom: http://www.artima.com/cppsource/safebool.html
    operator dummy_bool_type() const
    {
        return obj ? &PyRef::dummy_true : 0;
    }

private:
    PyRef(PyObjectT *obj_) :
        obj(obj_)
    {
    }

    PyObjectT *obj;
};

// Free-function spellings of PyRef::newref / PyRef::borrow with type deduction.
template <typename PyObjectT>
PyRef<PyObjectT> newref(PyObjectT *obj)
{
    return PyRef<PyObjectT>::newref(obj);
}

template <typename PyObjectT>
PyRef<PyObjectT> borrow(PyObjectT *obj)
{
    return PyRef<PyObjectT>::borrow(obj);
}

typedef PyRef<PyObject> PyObjectRef;
typedef PyRef<PyArrayObject> PyArrayRef;
typedef PyObjectRef PyFunc(const PyObjectRef &self, const PyObjectRef &args);
typedef PyObjectRef PyFuncWithKeywords(const PyObjectRef &self, const PyObjectRef &args, const PyObjectRef &kwargs);

// Adapts `func` to a CPython PyCFunction (METH_VARARGS).
//   - The result is returned as a new reference; a NULL result means an exception is set.
//   - A thrown PythonError is assumed to have set the Python exception already.
//   - A std::exception becomes a RuntimeError carrying what().
//   - Anything else becomes RuntimeError("<unknown error>").
template <PyFunc func>
PyObject* pywrap(PyObject *self, PyObject *args)
{
    try {
        return func(borrow(self), borrow(args)).newref();
    } catch(const PythonError &e) {
        assert(PyErr_Occurred());
        return NULL;
    } catch(const std::exception &e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return NULL;
    } catch(...) {
        PyErr_SetString(PyExc_RuntimeError, "<unknown error>");
        return NULL;
    }
}

// Keyword-argument (METH_VARARGS | METH_KEYWORDS) variant of pywrap.
template <PyFuncWithKeywords func>
PyObject* pywrap(PyObject *self, PyObject *args, PyObject *kwargs)
{
    try {
        return func(borrow(self), borrow(args), borrow(kwargs)).newref();
    } catch(const PythonError &e) {
        assert(PyErr_Occurred());
        return NULL;
    } catch(const std::exception &e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return NULL;
    } catch(...) {
        PyErr_SetString(PyExc_RuntimeError, "<unknown error>");
        return NULL;
    }
}

// Adapts `func` to a tp_init slot: 0 on success, -1 (exception set) on failure.
// Only the NULL-ness of func's result matters.
template <PyFuncWithKeywords func>
int pyinitproc(PyObject *self, PyObject *args, PyObject *kwargs)
{
    try {
        // borrow(), not newref(): the temporary PyRef releases the result here,
        // whereas newref() leaked one reference per call.
        return func(borrow(self), borrow(args), borrow(kwargs)).borrow() == NULL ? -1 : 0;
    } catch(const PythonError &e) {
        assert(PyErr_Occurred());
        return -1;
    } catch(const std::exception &e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return -1;
    } catch(...) {
        PyErr_SetString(PyExc_RuntimeError, "<unknown error>");
        return -1;
    }
}

// PyArg_ParseTuple "O&" converters and ndarray <-> Eigen map helpers (common.cpp).
int matrix_converter(PyObject *obj, void *dest);
int const_matrix_converter(PyObject *obj, void *dest);
Map<Vector, Aligned> pyvec(const PyArrayRef &obj);
Map<Matrix, Aligned> pymat(const PyArrayRef &obj);
Map<const Vector, Aligned> const_pyvec(const PyArrayRef &obj);
Map<const Matrix, Aligned> const_pymat(const PyArrayRef &obj);

// New Python float holding `value` (a NULL handle if allocation fails).
inline PyObjectRef pyfloat(Real value)
{
    return newref(PyFloat_FromDouble(value));
}

// Copies an Eigen column vector into a new 1-d float64 ndarray.
// Throws PythonError if NumPy cannot allocate the array.
template <typename Derived>
PyArrayRef pyvec(const DenseBase<Derived> &vec)
{
    ENSURE(vec.cols() == 1);
    npy_intp dims[] = { vec.size() };
    const PyObjectRef array = newref(PyArray_EMPTY(1, dims, NPY_DOUBLE, 0));
    // PyArray_EMPTY returns NULL with an exception set on failure; PyArray_Check
    // below would dereference it.
    if(!array) throw PythonError();
    PyArrayRef result = array.cast<PyArrayObject>();
    ENSURE(PyArray_Check(result.borrow()));
    pyvec(result) = vec;
    return result;
}

// Copies an Eigen matrix into a new C-ordered 2-d float64 ndarray.
// Throws PythonError if NumPy cannot allocate the array.
template <typename Derived>
PyArrayRef pymat(const MatrixBase<Derived> &mat)
{
    npy_intp dims[] = { mat.rows(), mat.cols() };
    const PyObjectRef array = newref(PyArray_EMPTY(2, dims, NPY_DOUBLE, 0));
    if(!array) throw PythonError();
    PyArrayRef result = array.cast<PyArrayObject>();
    ENSURE(PyArray_Check(result.borrow()));
    pymat(result) = mat;
    return result;
}

} // namespace libtheta

#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# python-config --includes only covers Python's own headers, so NumPy's include
# directory is queried from the interpreter as well.
CFLAGS=-Wall $(shell python-config --includes) -I$(shell python -c 'import numpy; print(numpy.get_include())') -Wno-unused -pthread -fPIC -shared -ggdb3 -O3 -march=native
CXXFLAGS=-std=gnu++0x $(CFLAGS)
LDFLAGS=$(shell python-config --ldflags) -lmkl_rt
OBJS=0-example.o common.o

.PHONY: check
check: libtheta.so
	python test.py

libtheta.so: $(OBJS)
	$(CXX) -shared $(CFLAGS) -o $@ $^ $(LDFLAGS)

# Both sources include common.hpp; without this, header edits do not trigger a rebuild.
$(OBJS): common.hpp

.PHONY: clean
clean:
	rm -f $(OBJS) libtheta.so
	find . -name "*.pyc" -type f -exec rm -f {} ";"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from libtheta import moments_from_samples | |
from scipy import random, linalg, eye | |
N = 10000 | |
D = 10 | |
mean, covar = moments_from_samples(random.rand(N), random.randn(D, N)) | |
print "residual norm^2:", linalg.norm(mean) ** 2, linalg.norm(covar - eye(D)) ** 2 |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.