Skip to content

@lucastheis /gsm.h
Created

Embed URL

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
#ifndef GSM_H
#define GSM_H
#include "Eigen/Core"
#include "distribution.h"
#include "exception.h"
#include <iostream>
#include <cmath>
using namespace Eigen;
using std::sqrt;
/**
 * Scale-mixture model exposed to Python as `isa.GSM`.
 * NOTE(review): "GSM" presumably stands for Gaussian scale mixture; the model
 * equations live in gsm.cpp, which is not visible here.
 *
 * Holds the data dimensionality, the number of scales, and the scales
 * themselves (setScales() enforces exactly numScales entries).
 */
class GSM : public Distribution {
public:
// dim: data dimensionality; numScales: number of scales (defaults 1 and 10).
GSM(int dim = 1, int numScales = 10);
// Stored data dimensionality (mDim).
inline int dim();
// Stored number of scales (mNumScales).
inline int numScales();
// Copy of the current scales (returned by value).
inline ArrayXd scales();
// Replace the scales; accepts a row or column vector with numScales entries
// and throws Exception otherwise.
inline void setScales(MatrixXd scales);
// Mean of the squared scales.
inline double variance();
// Divide all scales by sqrt(variance()) so that variance() becomes 1.
inline void normalize();
// Fit the model to data. maxIter and tol presumably bound the number of
// iterations and set the convergence tolerance. The returned flag is handed to
// Python unchanged — NOTE(review): presumably "converged"; confirm in gsm.cpp.
virtual bool train(const MatrixXd& data, int maxIter = 100, double tol = 1e-5);
// Draw numSamples samples from the model (presumably one sample per column).
virtual MatrixXd sample(int numSamples = 1);
// One value per column of data; presumably a draw from the posterior over
// scales for each data point — TODO confirm.
virtual Array<double, 1, Dynamic> samplePosterior(const MatrixXd& data);
// Posterior over scales for the data (presumably numScales x N) — TODO confirm.
virtual ArrayXXd posterior(const MatrixXd& data);
// The overloads taking sqNorms presumably accept precomputed squared norms of
// the data columns so they need not be recomputed — TODO confirm.
virtual ArrayXXd posterior(const MatrixXd& data, const RowVectorXd& sqNorms);
// Log of the joint density of data and scales (presumably numScales x N).
virtual ArrayXXd logJoint(const MatrixXd& data);
virtual ArrayXXd logJoint(const MatrixXd& data, const RowVectorXd& sqNorms);
// Log-likelihood, one value per column of data.
virtual Array<double, 1, Dynamic> logLikelihood(const MatrixXd& data);
virtual Array<double, 1, Dynamic> logLikelihood(const MatrixXd& data, const RowVectorXd& sqNorms);
// Energy, one value per column of data (NOTE(review): presumably the negative
// unnormalized log-density — confirm in gsm.cpp).
virtual Array<double, 1, Dynamic> energy(const MatrixXd& data);
virtual Array<double, 1, Dynamic> energy(const MatrixXd& data, const RowVectorXd& sqNorms);
// Gradient of the energy with respect to the data.
virtual ArrayXXd energyGradient(const MatrixXd& data);
protected:
// data dimensionality returned by dim()
int mDim;
// number of scales returned by numScales()
int mNumScales;
// current scales; setScales() stores them as a single column
ArrayXd mScales;
};
inline int GSM::dim() {
return mDim;
}
inline int GSM::numScales() {
return mNumScales;
}
inline ArrayXd GSM::scales() {
return mScales;
}
/**
 * Average of the squared scales, i.e. mean over all entries of scale * scale.
 */
inline double GSM::variance()
{
	return (mScales * mScales).mean();
}
inline void GSM::normalize() {
mScales /= sqrt(variance());
}
/**
 * Replaces the scales.
 *
 * Accepts a row or a column vector: whenever the input has more columns than
 * rows it is transposed first. After that it must be a single column with
 * exactly numScales() entries. Otherwise an Exception is thrown and the
 * current scales are left untouched.
 */
inline void GSM::setScales(MatrixXd scales) {
	const bool wide = scales.cols() > scales.rows();
	if(wide)
		scales.transposeInPlace();

	const bool validShape = scales.cols() == 1 && scales.rows() == mNumScales;
	if(!validShape)
		throw Exception("Wrong number of scales.");

	mScales = scales;
}
#endif
#ifndef GSMINTERFACE_H
#define GSMINTERFACE_H
#include "gsm.h"
#include "exception.h"
#include "pyutils.h"
#include "Eigen/Core"
#include <iostream>
using namespace Eigen;
/**
 * Python object layout for `isa.GSM`: the standard object header followed by
 * an owning pointer to the wrapped C++ model. The pointer is set to 0 by
 * GSM_new, assigned by GSM_init and deleted by GSM_dealloc.
 */
struct GSMObject {
PyObject_HEAD
// owned model instance; 0 until __init__ (GSM_init) has run
GSM* gsm;
};
/**
 * tp_new slot for `isa.GSM`.
 *
 * Allocates the Python object through the type's tp_alloc and leaves the
 * wrapped model pointer null; the model itself is created later by GSM_init.
 * `args` and `kwds` are not used here. Returns 0 if allocation fails.
 */
static PyObject* GSM_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
	GSMObject* object = reinterpret_cast<GSMObject*>(type->tp_alloc(type, 0));
	if(!object)
		return 0;

	// no model yet; GSM_dealloc may safely delete a null pointer
	object->gsm = 0;
	return reinterpret_cast<PyObject*>(object);
}
/**
 * Initialize GSM object (Python `__init__`).
 *
 * Accepts `dim` (required) and `num_scales` (optional, default 10) and creates
 * the wrapped C++ model. Returns 0 on success. On failure it returns -1 with a
 * Python exception set:
 * - a parse error from PyArg_ParseTupleAndKeywords,
 * - ValueError for non-positive sizes,
 * - RuntimeError if the GSM constructor throws an `Exception`.
 *
 * Calling `__init__` again replaces the model: the previous instance is deleted
 * instead of being leaked, and only after the new one was built successfully.
 */
static int GSM_init(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"dim", "num_scales", 0};
	int dim;
	int num_scales = 10;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "i|i", kwlist,
		&dim, &num_scales))
		return -1;

	if(dim < 1 || num_scales < 1) {
		PyErr_SetString(PyExc_ValueError, "Dimensionality and number of scales should be positive.");
		return -1;
	}

	// create actual GSM instance; C++ exceptions must not propagate into the
	// Python interpreter, so translate them into a Python exception
	GSM* gsm = 0;
	try {
		gsm = new GSM(dim, num_scales);
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return -1;
	}

	// release a model left over from a previous __init__ call
	delete self->gsm;
	self->gsm = gsm;

	return 0;
}
/**
 * tp_dealloc slot for `isa.GSM`.
 *
 * Destroys the wrapped C++ model (if one was created) and then returns the
 * object's memory through its type's tp_free.
 */
static void GSM_dealloc(GSMObject* self) {
	GSM* model = self->gsm;
	self->gsm = 0;
	delete model; // deleting a null pointer is a no-op

	PyTypeObject* type = self->ob_type;
	type->tp_free(reinterpret_cast<PyObject*>(self));
}
/**
 * Getter for `GSM.dim`: the data dimensionality reported by the wrapped model,
 * returned as a Python int.
 *
 * The (self, closure) parameter list matches CPython's `getter` typedef, which
 * is how the attribute machinery invokes this function.
 */
static PyObject* GSM_dim(GSMObject* self, void*) {
	return PyInt_FromLong(self->gsm->dim());
}
/**
 * Getter for `GSM.num_scales`: the number of scales of the wrapped model,
 * returned as a Python int.
 *
 * The (self, closure) parameter list matches CPython's `getter` typedef, which
 * is how the attribute machinery invokes this function.
 */
static PyObject* GSM_num_scales(GSMObject* self, void*) {
	return PyInt_FromLong(self->gsm->numScales());
}
/**
 * Getter for `GSM.scales`: a NumPy array built from a copy of the model's
 * scales (GSM::scales() returns by value).
 *
 * The (self, closure) parameter list matches CPython's `getter` typedef, which
 * is how the attribute machinery invokes this function.
 */
static PyObject* GSM_scales(GSMObject* self, void*) {
	return PyArray_FromMatrixXd(self->gsm->scales());
}
/**
 * Setter for `GSM.scales`: replaces the model's scales with the contents of a
 * NumPy array. The array is converted via PyArray_ToMatrixXd and validated by
 * GSM::setScales, which accepts a row or column vector with num_scales entries.
 *
 * Returns 0 on success. On failure it returns -1 with a Python exception set:
 * - TypeError when the attribute is deleted or the value is not an ndarray,
 * - RuntimeError when an `Exception` is thrown during conversion or validation.
 */
static int GSM_set_scales(GSMObject* self, PyObject* value, void*) {
	// CPython passes value == NULL for `del gsm.scales`; PyArray_Check would
	// dereference it
	if(!value) {
		PyErr_SetString(PyExc_TypeError, "Scales cannot be deleted.");
		return -1;
	}

	if(!PyArray_Check(value)) {
		PyErr_SetString(PyExc_TypeError, "Scales should be of type `ndarray`.");
		return -1;
	}

	try {
		self->gsm->setScales(PyArray_ToMatrixXd(value));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return -1;
	}

	return 0;
}
/**
 * Python method `GSM.variance()` (METH_NOARGS).
 *
 * Returns GSM::variance() as a Python float. If an `Exception` is thrown, it
 * returns 0 with RuntimeError set.
 *
 * The (self, unused) parameter list matches CPython's PyCFunction, which is
 * how METH_NOARGS methods are invoked.
 */
static PyObject* GSM_variance(GSMObject* self, PyObject*) {
	try {
		return PyFloat_FromDouble(self->gsm->variance());
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.normalize()` (METH_NOARGS).
 *
 * Calls GSM::normalize() on the wrapped model and returns None. If an
 * `Exception` is thrown, it returns 0 with RuntimeError set.
 *
 * The (self, unused) parameter list matches CPython's PyCFunction, which is
 * how METH_NOARGS methods are invoked.
 */
static PyObject* GSM_normalize(GSMObject* self, PyObject*) {
	try {
		self->gsm->normalize();
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}

	Py_INCREF(Py_None);
	return Py_None;
}
/**
 * Python method `GSM.train(data, max_iter=100, tol=1e-5)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd and
 * fits the wrapped model via GSM::train. Returns the flag GSM::train reports
 * as Python True/False (NOTE(review): presumably whether training converged;
 * confirm in gsm.cpp).
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_train(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", "max_iter", "tol", 0};
	PyObject* data;
	int max_iter = 100;
	double tol = 1e-5;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O|id", kwlist, &data, &max_iter, &tol))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		PyObject* result = self->gsm->train(PyArray_ToMatrixXd(data), max_iter, tol) ? Py_True : Py_False;
		Py_INCREF(result);
		return result;
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.posterior(data)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd,
 * evaluates GSM::posterior and returns the result as a NumPy array
 * (presumably num_scales x N — TODO confirm in gsm.cpp).
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_posterior(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", 0};
	PyObject* data;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->posterior(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		// caught by reference: no copy and no slicing of derived exception types
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.sample(num_samples=1)`.
 *
 * Draws `num_samples` samples via GSM::sample and returns them as a NumPy
 * array. On failure it returns 0 with a Python exception set:
 * - ValueError for a negative count,
 * - RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_sample(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"num_samples", 0};
	int num_samples = 1;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "|i", kwlist, &num_samples))
		return 0;

	// reject negative counts here rather than passing them to the model
	// (NOTE(review): GSM::sample presumably uses it as a matrix dimension)
	if(num_samples < 0) {
		PyErr_SetString(PyExc_ValueError, "Number of samples should be non-negative.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->sample(num_samples));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.sample_posterior(data)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd, calls
 * GSM::samplePosterior and returns its row of values (one per column of data)
 * as a NumPy array.
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_sample_posterior(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", 0};
	PyObject* data;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->samplePosterior(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		// caught by reference: no copy and no slicing of derived exception types
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.loglikelihood(data)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd and
 * returns GSM::logLikelihood (one value per column of data) as a NumPy array.
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_loglikelihood(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", 0};
	PyObject* data;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->logLikelihood(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		// caught by reference: no copy and no slicing of derived exception types
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.energy(data)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd and
 * returns GSM::energy (one value per column of data) as a NumPy array.
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_energy(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", 0};
	PyObject* data;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->energy(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		// caught by reference: no copy and no slicing of derived exception types
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `GSM.energy_gradient(data)`.
 *
 * Converts `data` (which must be a NumPy array) with PyArray_ToMatrixXd and
 * returns GSM::energyGradient as a NumPy array.
 *
 * On failure it returns 0 with a Python exception set: TypeError for
 * non-array data, RuntimeError if an `Exception` is thrown.
 */
static PyObject* GSM_energy_gradient(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"data", 0};
	PyObject* data;

	// read arguments
	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
		return 0;

	// make sure data is stored in NumPy array
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->energyGradient(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		// caught by reference: no copy and no slicing of derived exception types
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Attributes exposed on `isa.GSM` instances.
 *
 * `dim` and `num_scales` are read-only (no setter given). `scales` can be read
 * and assigned. The closure/doc fields are left 0, and the table ends with a
 * zero sentinel entry.
 */
static PyGetSetDef GSM_getset[] = {
{"dim", (getter)GSM_dim, 0, 0},
{"num_scales", (getter)GSM_num_scales, 0, 0},
{"scales", (getter)GSM_scales, (setter)GSM_set_scales, 0},
// sentinel
{0}
};
/**
 * Methods exposed on `isa.GSM` instances.
 *
 * Entries flagged METH_VARARGS|METH_KEYWORDS accept positional and keyword
 * arguments. `variance` and `normalize` are METH_NOARGS and take none. The
 * docstring field is left 0, and the table ends with a zero sentinel entry.
 */
static PyMethodDef GSM_methods[] = {
{"train", (PyCFunction)GSM_train, METH_VARARGS|METH_KEYWORDS, 0},
{"posterior", (PyCFunction)GSM_posterior, METH_VARARGS|METH_KEYWORDS, 0},
{"variance", (PyCFunction)GSM_variance, METH_NOARGS, 0},
{"normalize", (PyCFunction)GSM_normalize, METH_NOARGS, 0},
{"sample", (PyCFunction)GSM_sample, METH_VARARGS|METH_KEYWORDS, 0},
{"sample_posterior", (PyCFunction)GSM_sample_posterior, METH_VARARGS|METH_KEYWORDS, 0},
{"loglikelihood", (PyCFunction)GSM_loglikelihood, METH_VARARGS|METH_KEYWORDS, 0},
{"energy", (PyCFunction)GSM_energy, METH_VARARGS|METH_KEYWORDS, 0},
{"energy_gradient", (PyCFunction)GSM_energy_gradient, METH_VARARGS|METH_KEYWORDS, 0},
// sentinel
{0}
};
/**
 * Type object for `isa.GSM`, using the Python 2 static type layout
 * (PyObject_HEAD_INIT followed by an explicit ob_size).
 *
 * Lifecycle: objects are allocated by GSM_new, initialized by GSM_init, and
 * freed by GSM_dealloc. Attributes come from GSM_getset and methods from
 * GSM_methods.
 *
 * Only Py_TPFLAGS_DEFAULT is set. Py_TPFLAGS_BASETYPE is absent, so Python
 * code cannot subclass this type. tp_alloc is left 0 and is expected to be
 * inherited when PyType_Ready is called on this type (done in initisa).
 */
static PyTypeObject GSM_type = {
PyObject_HEAD_INIT(0)
0, /*ob_size*/
"isa.GSM", /*tp_name*/
sizeof(GSMObject), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor)GSM_dealloc, /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_compare*/
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash */
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT, /*tp_flags*/
0, /*tp_doc*/
0, /*tp_traverse*/
0, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
0, /*tp_iternext*/
GSM_methods, /*tp_methods*/
0, /*tp_members*/
GSM_getset, /*tp_getset*/
0, /*tp_base*/
0, /*tp_dict*/
0, /*tp_descr_get*/
0, /*tp_descr_set*/
0, /*tp_dictoffset*/
(initproc)GSM_init, /*tp_init*/
0, /*tp_alloc*/
GSM_new, /*tp_new*/
};
#endif
#include <Python.h>
#include <arrayobject.h>
#include <structmember.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include "isainterface.h"
#include "gsminterface.h"
/**
 * Module initializer for `isa` (Python 2 style, returns nothing).
 *
 * Seeds rand() from the wall clock, imports the NumPy C API, creates the
 * module and registers the `ISA` and `GSM` types. On failure it returns early
 * and leaves the Python exception set by the failing call in place.
 */
PyMODINIT_FUNC initisa() {
	// seed rand() from the wall clock. Unsigned arithmetic wraps instead of
	// overflowing (the old signed tv_usec * tv_sec product could overflow a
	// 32-bit long). Adding the microseconds, rather than multiplying by them,
	// keeps a zero microsecond reading from collapsing the seed to 0.
	timeval now;
	gettimeofday(&now, 0);
	srand(static_cast<unsigned>(now.tv_sec) * 1000000u + static_cast<unsigned>(now.tv_usec));

	// initialize NumPy
	import_array();

	// create module object
	PyObject* module = Py_InitModule("isa", 0);
	if(!module)
		return;

	// initialize types
	if(PyType_Ready(&ISA_type) < 0)
		return;
	if(PyType_Ready(&GSM_type) < 0)
		return;

	// add types to module (PyModule_AddObject steals one reference each)
	Py_INCREF(&ISA_type);
	PyModule_AddObject(module, "ISA", reinterpret_cast<PyObject*>(&ISA_type));
	Py_INCREF(&GSM_type);
	PyModule_AddObject(module, "GSM", reinterpret_cast<PyObject*>(&GSM_type));
}
import os
import numpy
from distutils.core import setup, Extension
from distutils.ccompiler import new_compiler
# Build configuration for the `isa` C++ extension (ISA and GSM models).
#
# NumPy's C headers are located with numpy.get_include(), the supported way to
# find them. Hard-coding '<numpy>/core/include' breaks on NumPy layouts that
# keep the headers elsewhere (e.g. NumPy 2.x moved them under numpy/_core).
modules = [
    Extension('isa',
        language='c++',
        sources=[
            'code/isa/src/isa.cpp',
            'code/isa/src/gsm.cpp',
            'code/isa/src/utils.cpp',
            'code/isa/src/module.cpp',
            'code/isa/src/distribution.cpp'],
        include_dirs=[
            'code',
            'code/isa/include',
            # module.cpp includes <arrayobject.h> directly, so point at the
            # numpy/ subdirectory of NumPy's include directory
            os.path.join(numpy.get_include(), 'numpy')],
        library_dirs=[],
        libraries=[],
        extra_link_args=[
            '-lgomp'],
        extra_compile_args=[
            '-fopenmp',
            '-Wno-parentheses',
            '-Wno-write-strings'])]

setup(
    name='isa',
    version='0.1',
    description='',
    ext_modules=modules)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.