-
-
Save qingkaikong/11381546 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# A pure-Python implementation of Dynamic Time Warping
# http://en.wikipedia.org/wiki/Dynamic_time_warping
class Dtw(object):
    '''
    Dynamic Time Warping between two sequences.

    Cells are addressed by (i1, i2) index pairs into seq1 and seq2.  The
    accumulated-cost table lives in ``self._map``: the virtual cell
    (-1, -1) is the origin with cost 0.0, and every other cell that has a
    -1 index is unreachable (cost ``inf``).
    '''

    def __init__(self, seq1, seq2, distance_func=None):
        '''
        seq1, seq2 are two indexable sequences (e.g. lists).
        distance_func(a, b) calculates the local distance between an
        element of seq1 and an element of seq2.  When it is omitted, the
        absolute difference ``abs(a - b)`` is used, which requires numeric
        elements.
        '''
        self._seq1 = seq1
        self._seq2 = seq2
        # The former default ``lambda: 0`` accepted no arguments and raised
        # TypeError on its first call from get_distance().
        self._distance_func = (distance_func if distance_func
                               else lambda a, b: abs(a - b))
        # Accumulated DTW cost per cell, seeded with the origin.
        self._map = {(-1, -1): 0.0}
        # Memoized local distances per cell.
        self._distance_matrix = {}
        self._path = []

    def get_distance(self, i1, i2):
        '''
        Return the local distance between seq1[i1] and seq2[i2] (memoized).
        '''
        key = (i1, i2)
        # Membership test rather than truthiness: a distance of 0 is a
        # valid cached value and must not trigger a recomputation.
        if key not in self._distance_matrix:
            self._distance_matrix[key] = self._distance_func(
                self._seq1[i1], self._seq2[i2])
        return self._distance_matrix[key]

    def calculate_backward(self, i1, i2):
        '''
        Calculate the dtw distance between
        seq1[:i1 + 1] and seq2[:i2 + 1].

        The table is filled bottom-up, row by row, instead of recursively,
        so long sequences do not hit Python's recursion limit.  Afterwards
        every cell of the rectangle [-1, i1] x [-1, i2] is present in
        self._map, which is what get_path() walks.
        '''
        cost = self._map
        if (i1, i2) in cost:
            return cost[(i1, i2)]
        inf = float('inf')
        if i1 == -1 or i2 == -1:
            # An empty prefix cannot align with a non-empty one.
            cost[(i1, i2)] = inf
            return inf
        # Border cells, so every neighbour lookup below hits the table.
        for a in range(i1 + 1):
            cost.setdefault((a, -1), inf)
        for b in range(i2 + 1):
            cost.setdefault((-1, b), inf)
        for a in range(i1 + 1):
            for b in range(i2 + 1):
                if (a, b) in cost:
                    # Already computed by an earlier call; values for a
                    # cell depend only on the prefixes, so reuse them.
                    continue
                cost[(a, b)] = self.get_distance(a, b) + min(
                    cost[(a - 1, b)], cost[(a, b - 1)], cost[(a - 1, b - 1)])
        return cost[(i1, i2)]

    def get_path(self):
        '''
        Calculate the path mapping.

        Returns the optimal warping path as a list of (i1, i2) index pairs,
        ordered from the last cell (len(seq1) - 1, len(seq2) - 1) back to
        (0, 0).  calculate() is run first if it has not been called yet.
        When either sequence is empty there is no alignment and [] is
        returned.
        '''
        if not self._seq1 or not self._seq2:
            self._path = []
            return self._path
        end = (len(self._seq1) - 1, len(self._seq2) - 1)
        if end not in self._map:
            self.calculate()
        cost = self._map
        path = []
        i1, i2 = end
        while (i1, i2) != (-1, -1):
            path.append((i1, i2))
            # Same neighbour order as calculate_backward; min() keeps the
            # first minimum, so ties resolve deterministically.
            i1, i2 = min((i1 - 1, i2), (i1, i2 - 1), (i1 - 1, i2 - 1),
                         key=cost.__getitem__)
        # Rebuild rather than append so repeated calls don't duplicate.
        self._path = path
        return self._path

    def calculate(self):
        '''
        Return the DTW distance between the whole of seq1 and seq2.
        '''
        return self.calculate_backward(len(self._seq1) - 1,
                                       len(self._seq2) - 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment