Skip to content

Instantly share code, notes, and snippets.

@non117
Last active December 11, 2015 21:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save non117/4662298 to your computer and use it in GitHub Desktop.
Save non117/4662298 to your computer and use it in GitHub Desktop.
DPマッチングを計算するPythonのC++拡張.
#include<Python/Python.h>

#include<functional>
#include<map>
#include<new>
#include<queue>
#include<vector>

using namespace std;
// A cell coordinate in the DP grid: x indexes the first sequence, y the second.
// Both components are stored as `short`; the two-int constructor narrows.
struct Point {
    short x;
    short y;

    Point() : x(0), y(0) {}
    Point(int row, int col) : x(static_cast<short>(row)), y(static_cast<short>(col)) {}
};
// One cell of the search grid, also used as a priority-queue entry.
struct node {
    float cost;  // accumulated path cost; the default 1e30 marks "not reached yet"
    Point prev;  // predecessor cell on the cheapest path found so far
    Point cur;   // coordinate of this cell

    node() : cost(static_cast<float>(1e30)) {}
    node(float c, const Point& from, const Point& at) : cost(c), prev(from), cur(at) {}

    // Orders by cost only, so std::greater<node> yields a cheapest-first heap.
    bool operator>(const node& other) const { return other.cost < cost; }
};
typedef node Node;
// Output of dynamic_timewarp: one entry per index of the first sequence.
// array holds the aligned sample of the second sequence, delay the index
// offset between the two. Both start zero-filled.
struct result {
    std::vector<float> array;
    std::vector<short> delay;

    result(int size) {
        array.resize(size);
        delay.resize(size);
    }
};
// Squared difference (x - y)^2, the local distance between two samples.
float sd(float x, float y) {
    float delta = x;
    delta -= y;
    return delta * delta;
}
// Dynamic time warping (DP matching) of sequence `a` (length m) against
// sequence `b` (length n).
//
// The grid of cells (x, y) is searched with Dijkstra's algorithm from (0,0)
// to (m-1,n-1). Each step moves to (x,y+1), (x+1,y+1) or (x+1,y) and adds
// sd(a[x], b[y]) of the cell entered. The cheapest path is then walked back
// from (m-1,n-1) to (0,0).
//
// Returns a `result` of size m. For each cell (i, j) on the path:
//   array[i] = b[j], delay[i] = i - j.
// When several path cells share the same i, the one with the smallest j
// wins, because cells are written from the end of the path towards the start.
//
// Empty input (m <= 0 or n <= 0) yields a zero-filled result of size
// max(m, 0) instead of reading out of bounds.
//
// NOTE(review): cell coordinates are stored in Point as `short`, so inputs
// longer than 32768 samples are not supported. Memory use is m * n nodes.
// A path cost reaching the node default of 1e30 (or NaN samples) makes cells
// look unreachable.
result dynamic_timewarp(float a[], float b[], int m, int n){
    if (m <= 0 || n <= 0) {
        return result(m > 0 ? m : 0);
    }

    vector<vector<node> > path_matrix(m, vector<node>(n));
    Node start(sd(a[0], b[0]), Point(-1, -1), Point(0, 0));
    path_matrix[0][0] = start;

    // Dijkstra's algorithm; std::greater turns the priority_queue into a
    // cheapest-first min-heap.
    priority_queue<Node, vector<Node>, greater<Node> > p_queue;
    p_queue.push(start);
    static const short dir_x[] = {0, 1, 1};
    static const short dir_y[] = {1, 1, 0};
    while (!p_queue.empty()) {
        Node cur_node = p_queue.top();
        p_queue.pop();
        // Stale entry: a cheaper path to this cell was already recorded.
        if (path_matrix[cur_node.cur.x][cur_node.cur.y].cost < cur_node.cost) {
            continue;
        }
        // Goal settled; with non-negative step costs it cannot improve further.
        if (cur_node.cur.x == m - 1 && cur_node.cur.y == n - 1) {
            break;
        }
        for (int i = 0; i < 3; i++) {
            int nx = cur_node.cur.x + dir_x[i];
            int ny = cur_node.cur.y + dir_y[i];
            // Check bounds BEFORE touching a/b so the last row/column never
            // reads a[m] or b[n].
            if (nx >= m || ny >= n) {
                continue;
            }
            float new_cost = cur_node.cost + sd(a[nx], b[ny]);
            if (path_matrix[nx][ny].cost > new_cost) {
                path_matrix[nx][ny].cost = new_cost;
                path_matrix[nx][ny].prev = cur_node.cur;
                p_queue.push(Node(new_cost, cur_node.cur, Point(nx, ny)));
            }
        }
    }

    // Follow predecessor links back from the goal; (0,0) has prev (-1,-1).
    result r(m);
    int x = m - 1;
    int y = n - 1;
    while (x != -1) {
        r.array[x] = b[y];
        r.delay[x] = static_cast<short>(x - y);
        const Point& prev = path_matrix[x][y].prev;
        x = prev.x;
        y = prev.y;
    }
    return r;
}
static PyObject* dtw(PyObject* self, PyObject* args){
PyObject *lisa, *lisb;
int m, n;
if (!PyArg_ParseTuple(args, "OO", &lisa, &lisb)) {
return NULL;
}
m = (int)PyList_Size(lisa);
n = (int)PyList_Size(lisb);
float a[m],b[n];
for (int i=0; i < m; i++) {
a[i] = PyFloat_AsDouble(PyList_GetItem(lisa, (Py_ssize_t)i));
}
for (int i=0; i < n; i++) {
b[i] = PyFloat_AsDouble(PyList_GetItem(lisb, (Py_ssize_t)i));
}
Py_DECREF(lisa); Py_DECREF(lisb);
result temp = dynamic_timewarp(a, b, m, n);
PyObject *array = PyList_New((Py_ssize_t)temp.array.size());
PyObject *delay = PyList_New((Py_ssize_t)temp.delay.size());
for (int i=0; i < m; i++) {
PyList_SetItem(array, (Py_ssize_t)i, PyFloat_FromDouble(temp.array[i]));
}
for (int i=0; i < m; i++) {
PyList_SetItem(delay, (Py_ssize_t)i, PyFloat_FromDouble(temp.delay[i]));
}
return Py_BuildValue("OO",array, delay);
}
// Method table of the dtw module; the all-NULL entry terminates the list.
static PyMethodDef methods[] = {
    {"dtw", dtw, METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}
};
#ifdef __cplusplus
extern "C" {
#endif
void initdtw(){
Py_InitModule("dtw", methods);
}
#ifdef __cplusplus
} // extern "C"
#endif
# Build script for the `dtw` extension module, compiled from dtw.cpp.
# NOTE: distutils is deprecated (PEP 632) and removed in Python 3.12.
from distutils.core import Extension, setup

dtw_extension = Extension("dtw", ["dtw.cpp"])

setup(
    name="dtw",
    version="1.0",
    ext_modules=[dtw_extension],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment