Created
June 13, 2012 11:41
-
-
Save lucastheis/2923577 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef GSM_H | |
#define GSM_H | |
#include "Eigen/Core" | |
#include "distribution.h" | |
#include "exception.h" | |
#include <iostream> | |
#include <cmath> | |
using namespace Eigen; | |
using std::sqrt; | |
/**
 * GSM (Gaussian scale mixture) distribution, derived from Distribution.
 *
 * Stores the data dimensionality, the number of scales and the scales
 * themselves. The small accessors are defined inline further down in this
 * header; the virtual methods are implemented out of line (presumably in
 * gsm.cpp, which setup.py compiles).
 */
class GSM : public Distribution { | |
public: | |
// dim: data dimensionality; numScales: number of scales
GSM(int dim = 1, int numScales = 10); | |
// stored data dimensionality
inline int dim(); | |
// stored number of scales
inline int numScales(); | |
// copy of the scales
inline ArrayXd scales(); | |
// replaces the scales; input must hold numScales() values (row or column
// vector), otherwise an Exception is thrown
inline void setScales(MatrixXd scales); | |
// mean of the squared scales
inline double variance(); | |
// divides the scales by sqrt(variance()) so that variance() becomes 1
inline void normalize(); | |
// fits the model to data; the Python wrapper exposes the returned bool
// (maxIter/tol presumably bound the optimization -- confirm in gsm.cpp)
virtual bool train(const MatrixXd& data, int maxIter = 100, double tol = 1e-5); | |
// draws numSamples samples
virtual MatrixXd sample(int numSamples = 1); | |
// returns a row array, presumably one sampled value per data column -- confirm
virtual Array<double, 1, Dynamic> samplePosterior(const MatrixXd& data); | |
// the overloads taking sqNorms presumably accept precomputed squared norms
// of the data columns to avoid recomputing them -- confirm in gsm.cpp
virtual ArrayXXd posterior(const MatrixXd& data); | |
virtual ArrayXXd posterior(const MatrixXd& data, const RowVectorXd& sqNorms); | |
virtual ArrayXXd logJoint(const MatrixXd& data); | |
virtual ArrayXXd logJoint(const MatrixXd& data, const RowVectorXd& sqNorms); | |
virtual Array<double, 1, Dynamic> logLikelihood(const MatrixXd& data); | |
virtual Array<double, 1, Dynamic> logLikelihood(const MatrixXd& data, const RowVectorXd& sqNorms); | |
virtual Array<double, 1, Dynamic> energy(const MatrixXd& data); | |
virtual Array<double, 1, Dynamic> energy(const MatrixXd& data, const RowVectorXd& sqNorms); | |
virtual ArrayXXd energyGradient(const MatrixXd& data); | |
protected: | |
// data dimensionality
int mDim; | |
// number of scales; setScales() enforces that mScales has this many rows
int mNumScales; | |
// the scales, stored as a column array
ArrayXd mScales; | |
}; | |
inline int GSM::dim() { | |
return mDim; | |
} | |
inline int GSM::numScales() { | |
return mNumScales; | |
} | |
/** Returns a copy of the scales; changing it does not affect this GSM. */
inline ArrayXd GSM::scales() {
	ArrayXd copy = this->mScales;
	return copy;
}
inline double GSM::variance() { | |
return mScales.square().mean(); | |
} | |
inline void GSM::normalize() { | |
mScales /= sqrt(variance()); | |
} | |
/**
 * Replaces the scales.
 *
 * Input that is wider than tall (e.g. a row vector) is transposed first.
 * Afterwards it must be a single column with exactly numScales() entries.
 *
 * @throws Exception if the shape does not match
 */
inline void GSM::setScales(MatrixXd scales) {
	const bool wide = scales.rows() < scales.cols();
	if(wide)
		scales.transposeInPlace();

	const bool shapeOk = scales.cols() == 1 && scales.rows() == mNumScales;
	if(!shapeOk)
		throw Exception("Wrong number of scales.");

	mScales = scales;
}
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef GSMINTERFACE_H | |
#define GSMINTERFACE_H | |
#include "gsm.h" | |
#include "exception.h" | |
#include "pyutils.h" | |
#include "Eigen/Core" | |
#include <iostream> | |
using namespace Eigen; | |
/**
 * Python-side object wrapping a C++ GSM instance.
 */
struct GSMObject { | |
// standard Python object header; must come first so that a GSMObject* can
// be used where a PyObject* is expected
PyObject_HEAD | |
// owned GSM instance: set to 0 in GSM_new, created in GSM_init,
// deleted in GSM_dealloc
GSM* gsm; | |
}; | |
/** | |
* Create a new GSM object. | |
*/ | |
static PyObject* GSM_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { | |
PyObject* self = type->tp_alloc(type, 0); | |
if(self) | |
reinterpret_cast<GSMObject*>(self)->gsm = 0; | |
return self; | |
} | |
/** | |
* Initialize GSM object. | |
*/ | |
static int GSM_init(GSMObject* self, PyObject* args, PyObject* kwds) { | |
char* kwlist[] = {"dim", "num_scales", 0}; | |
int dim; | |
int num_scales = 10; | |
// read arguments | |
if(!PyArg_ParseTupleAndKeywords(args, kwds, "i|i", kwlist, | |
&dim, &num_scales)) | |
return -1; | |
// create actual GSM instance | |
self->gsm = new GSM(dim, num_scales); | |
return 0; | |
} | |
/** | |
* Delete GSM object. | |
*/ | |
static void GSM_dealloc(GSMObject* self) { | |
// delete actual GSM instance | |
delete self->gsm; | |
// delete GSM object | |
self->ob_type->tp_free(reinterpret_cast<PyObject*>(self)); | |
} | |
/** | |
* Return number of visible units. | |
*/ | |
static PyObject* GSM_dim(GSMObject* self, PyObject*, void*) { | |
return PyInt_FromLong(self->gsm->dim()); | |
} | |
/** | |
* Return number of hidden units. | |
*/ | |
static PyObject* GSM_num_scales(GSMObject* self, PyObject*, void*) { | |
return PyInt_FromLong(self->gsm->numScales()); | |
} | |
/** | |
* Return copy of linear basis. | |
*/ | |
static PyObject* GSM_scales(GSMObject* self, PyObject*, void*) { | |
return PyArray_FromMatrixXd(self->gsm->scales()); | |
} | |
/** | |
* Replace linear basis. | |
*/ | |
static int GSM_set_scales(GSMObject* self, PyObject* value, void*) { | |
if(!PyArray_Check(value)) { | |
PyErr_SetString(PyExc_TypeError, "Scales should be of type `ndarray`."); | |
return -1; | |
} | |
try { | |
self->gsm->setScales(PyArray_ToMatrixXd(value)); | |
} catch(Exception exception) { | |
PyErr_SetString(PyExc_RuntimeError, exception.message()); | |
return -1; | |
} | |
return 0; | |
} | |
/**
 * Python method `variance()`: returns GSM::variance() (mean of the squared
 * scales) as a float.
 *
 * Registered with METH_NOARGS, so the parameter list matches PyCFunction:
 * (self, unused argument that CPython passes as NULL).
 *
 * @return new float, or 0 with a Python exception set
 */
static PyObject* GSM_variance(GSMObject* self, PyObject*) {
	try {
		return PyFloat_FromDouble(self->gsm->variance());
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `normalize()`: rescales the scales in place via
 * GSM::normalize() so that variance() becomes 1.
 *
 * Registered with METH_NOARGS, so the parameter list matches PyCFunction:
 * (self, unused argument that CPython passes as NULL).
 *
 * @return None, or 0 with a Python exception set
 */
static PyObject* GSM_normalize(GSMObject* self, PyObject*) {
	try {
		self->gsm->normalize();
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}

	Py_INCREF(Py_None);
	return Py_None;
}
/**
 * Python method `train(data, max_iter=100, tol=1e-5)`.
 *
 * Hands `data` (which must be a NumPy array) to GSM::train and reports the
 * bool it returns as True/False.
 *
 * @return new reference to Py_True/Py_False, or 0 with a Python exception set
 */
static PyObject* GSM_train(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", "max_iter", "tol", 0};

	PyObject* data;
	int max_iter = 100;
	double tol = 1e-5;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O|id", keywords, &data, &max_iter, &tol))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	bool trained;
	try {
		trained = self->gsm->train(PyArray_ToMatrixXd(data), max_iter, tol);
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}

	PyObject* result = trained ? Py_True : Py_False;
	Py_INCREF(result);
	return result;
}
/**
 * Python method `posterior(data)`.
 *
 * Converts `data` (which must be a NumPy array), evaluates GSM::posterior
 * and returns the result as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_posterior(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", 0};
	PyObject* data;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", keywords, &data))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	PyObject* result = 0;
	try {
		result = PyArray_FromMatrixXd(self->gsm->posterior(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
	}
	return result;
}
/**
 * Python method `sample(num_samples=1)`.
 *
 * Returns the samples drawn by GSM::sample as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_sample(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* kwlist[] = {"num_samples", 0};

	int num_samples = 1;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "|i", kwlist, &num_samples))
		return 0;

	// a negative count is meaningless; without this check it would be passed
	// to GSM::sample (and from there presumably into Eigen sizing) unchecked
	if(num_samples < 0) {
		PyErr_SetString(PyExc_ValueError, "Number of samples has to be non-negative.");
		return 0;
	}

	try {
		return PyArray_FromMatrixXd(self->gsm->sample(num_samples));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
		return 0;
	}
}
/**
 * Python method `sample_posterior(data)`.
 *
 * Converts `data` (which must be a NumPy array), calls GSM::samplePosterior
 * and returns the result as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_sample_posterior(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", 0};
	PyObject* data;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", keywords, &data))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	PyObject* result = 0;
	try {
		result = PyArray_FromMatrixXd(self->gsm->samplePosterior(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
	}
	return result;
}
/**
 * Python method `loglikelihood(data)`.
 *
 * Converts `data` (which must be a NumPy array), evaluates
 * GSM::logLikelihood and returns the result as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_loglikelihood(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", 0};
	PyObject* data;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", keywords, &data))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	PyObject* result = 0;
	try {
		result = PyArray_FromMatrixXd(self->gsm->logLikelihood(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
	}
	return result;
}
/**
 * Python method `energy(data)`.
 *
 * Converts `data` (which must be a NumPy array), evaluates GSM::energy and
 * returns the result as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_energy(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", 0};
	PyObject* data;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", keywords, &data))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	PyObject* result = 0;
	try {
		result = PyArray_FromMatrixXd(self->gsm->energy(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
	}
	return result;
}
/**
 * Python method `energy_gradient(data)`.
 *
 * Converts `data` (which must be a NumPy array), evaluates
 * GSM::energyGradient and returns the result as a NumPy array.
 *
 * @return new array, or 0 with a Python exception set
 */
static PyObject* GSM_energy_gradient(GSMObject* self, PyObject* args, PyObject* kwds) {
	char* keywords[] = {"data", 0};
	PyObject* data;

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "O", keywords, &data))
		return 0;

	// anything other than a NumPy array is rejected up front
	if(!PyArray_Check(data)) {
		PyErr_SetString(PyExc_TypeError, "Data has to be stored in a NumPy array.");
		return 0;
	}

	PyObject* result = 0;
	try {
		result = PyArray_FromMatrixXd(self->gsm->energyGradient(PyArray_ToMatrixXd(data)));
	} catch(Exception& exception) {
		PyErr_SetString(PyExc_RuntimeError, exception.message());
	}
	return result;
}
/**
 * Attribute table of isa.GSM: read-only `dim` and `num_scales`, and a
 * readable/writable `scales`. Terminated by an all-zero sentinel entry.
 */
static PyGetSetDef GSM_getset[] = {
	{"dim",        reinterpret_cast<getter>(GSM_dim),        0,                                        0},
	{"num_scales", reinterpret_cast<getter>(GSM_num_scales), 0,                                        0},
	{"scales",     reinterpret_cast<getter>(GSM_scales),     reinterpret_cast<setter>(GSM_set_scales), 0},
	{0, 0, 0, 0, 0}
};
/**
 * Method table of isa.GSM. `variance` and `normalize` take no arguments;
 * all other methods accept positional and keyword arguments. Terminated by
 * an all-zero sentinel entry.
 */
static PyMethodDef GSM_methods[] = {
	{"train",            reinterpret_cast<PyCFunction>(GSM_train),            METH_VARARGS | METH_KEYWORDS, 0},
	{"posterior",        reinterpret_cast<PyCFunction>(GSM_posterior),        METH_VARARGS | METH_KEYWORDS, 0},
	{"variance",         reinterpret_cast<PyCFunction>(GSM_variance),         METH_NOARGS,                  0},
	{"normalize",        reinterpret_cast<PyCFunction>(GSM_normalize),        METH_NOARGS,                  0},
	{"sample",           reinterpret_cast<PyCFunction>(GSM_sample),           METH_VARARGS | METH_KEYWORDS, 0},
	{"sample_posterior", reinterpret_cast<PyCFunction>(GSM_sample_posterior), METH_VARARGS | METH_KEYWORDS, 0},
	{"loglikelihood",    reinterpret_cast<PyCFunction>(GSM_loglikelihood),    METH_VARARGS | METH_KEYWORDS, 0},
	{"energy",           reinterpret_cast<PyCFunction>(GSM_energy),           METH_VARARGS | METH_KEYWORDS, 0},
	{"energy_gradient",  reinterpret_cast<PyCFunction>(GSM_energy_gradient),  METH_VARARGS | METH_KEYWORDS, 0},
	{0, 0, 0, 0}
};
/**
 * Python type object for isa.GSM.
 *
 * Slots are given positionally in the field order of Python 2's
 * PyTypeObject. Only dealloc, methods, getset, init and new are provided;
 * every other slot is 0. tp_alloc is 0 here even though GSM_new calls it,
 * so it is expected to be inherited when PyType_Ready is run on this type.
 */
static PyTypeObject GSM_type = {
	PyObject_HEAD_INIT(0)
	0,                                          // ob_size
	"isa.GSM",                                  // tp_name
	sizeof(GSMObject),                          // tp_basicsize
	0,                                          // tp_itemsize
	reinterpret_cast<destructor>(GSM_dealloc),  // tp_dealloc
	0,                                          // tp_print
	0,                                          // tp_getattr
	0,                                          // tp_setattr
	0,                                          // tp_compare
	0,                                          // tp_repr
	0,                                          // tp_as_number
	0,                                          // tp_as_sequence
	0,                                          // tp_as_mapping
	0,                                          // tp_hash
	0,                                          // tp_call
	0,                                          // tp_str
	0,                                          // tp_getattro
	0,                                          // tp_setattro
	0,                                          // tp_as_buffer
	Py_TPFLAGS_DEFAULT,                         // tp_flags
	0,                                          // tp_doc
	0,                                          // tp_traverse
	0,                                          // tp_clear
	0,                                          // tp_richcompare
	0,                                          // tp_weaklistoffset
	0,                                          // tp_iter
	0,                                          // tp_iternext
	GSM_methods,                                // tp_methods
	0,                                          // tp_members
	GSM_getset,                                 // tp_getset
	0,                                          // tp_base
	0,                                          // tp_dict
	0,                                          // tp_descr_get
	0,                                          // tp_descr_set
	0,                                          // tp_dictoffset
	reinterpret_cast<initproc>(GSM_init),       // tp_init
	0,                                          // tp_alloc
	GSM_new,                                    // tp_new
};
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Python.h> | |
#include <arrayobject.h> | |
#include <structmember.h> | |
#include <stdlib.h> | |
#include <time.h> | |
#include "isainterface.h" | |
#include "gsminterface.h" | |
/**
 * Entry point of the `isa` extension module (Python 2 style init function).
 *
 * Seeds the C random number generator, initializes NumPy's C API, creates
 * the module and registers the ISA and GSM types with it. On failure it
 * returns early, leaving the Python error set by the failing call.
 */
PyMODINIT_FUNC initisa() {
	// Seed rand() from the current time. Unsigned arithmetic wraps instead of
	// overflowing (the old signed product tv_usec * tv_sec could overflow a
	// 32-bit long). Combining seconds and microseconds, rather than multiplying
	// them, also avoids a seed of 0 whenever tv_usec happens to be 0.
	timeval now;
	gettimeofday(&now, 0);
	srand(static_cast<unsigned int>(now.tv_sec) * 1000000u
		+ static_cast<unsigned int>(now.tv_usec));

	// initialize NumPy
	import_array();

	// create module object; bail out before the module is dereferenced
	PyObject* module = Py_InitModule("isa", 0);
	if(!module)
		return;

	// initialize types
	if(PyType_Ready(&ISA_type) < 0)
		return;
	if(PyType_Ready(&GSM_type) < 0)
		return;

	// add types to module (PyModule_AddObject steals the extra reference)
	Py_INCREF(&ISA_type);
	PyModule_AddObject(module, "ISA", reinterpret_cast<PyObject*>(&ISA_type));
	Py_INCREF(&GSM_type);
	PyModule_AddObject(module, "GSM", reinterpret_cast<PyObject*>(&GSM_type));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy | |
from distutils.core import setup, Extension | |
from distutils.ccompiler import new_compiler | |
# Build configuration for the `isa` C++ extension module.
#
# NumPy's header directory is located with numpy.get_include() instead of a
# path hard-coded relative to the package. That hard-coded path breaks on
# NumPy versions that move their headers. The extra 'numpy' component is
# needed because the sources include <arrayobject.h> directly rather than
# <numpy/arrayobject.h>.
modules = [
    Extension('isa',
        language='c++',
        sources=[
            'code/isa/src/isa.cpp',
            'code/isa/src/gsm.cpp',
            'code/isa/src/utils.cpp',
            'code/isa/src/module.cpp',
            'code/isa/src/distribution.cpp'],
        include_dirs=[
            'code',
            'code/isa/include',
            os.path.join(numpy.get_include(), 'numpy')],
        library_dirs=[],
        libraries=[],
        # GCC's OpenMP runtime, required because of -fopenmp below
        extra_link_args=[
            '-lgomp'],
        extra_compile_args=[
            '-fopenmp',
            '-Wno-parentheses',
            '-Wno-write-strings'])]

setup(
    name='isa',
    version='0.1',
    description='',
    ext_modules=modules)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment