-
-
Save kylemcdonald/76b6f18fb4026e01196282b59bd31e7e to your computer and use it in GitHub Desktop.
# based on https://github.com/kylerbrown/ezdtw
# with modifications to be fully njit-able
import numpy as np | |
from numba import njit | |
@njit
def sqeuclidean(a, b):
    '''Return the squared Euclidean distance between `a` and `b`.

    The element-wise difference is squared and summed over every element;
    no square root is taken.
    '''
    delta = a - b
    return np.sum(delta ** 2)
@njit
def cdist_jit(a, b):
    '''Compute the pairwise Euclidean distances between the rows of `a` and `b`.

    a: array with one observation per row (expected to be 2-D)
    b: array with one observation per row (expected to be 2-D)
    returns: array of shape (a.shape[0], b.shape[0]) whose [i, j] entry is
    the Euclidean distance between a[i] and b[j]

    The squared distances are written into a buffer allocated with `a.dtype`
    and the square root is taken once over the whole buffer at the end.
    '''
    na = a.shape[0]
    nb = b.shape[0]
    distances = np.empty((na, nb), dtype=a.dtype)
    for i in range(na):
        for j in range(nb):
            distances[i, j] = sqeuclidean(a[i], b[j])
    return np.sqrt(distances)
@njit
def dtw_distance(distances):
    '''calculate minimum cumulative distance

    distances: 2-D matrix of pairwise frame distances
    returns: matrix of the same shape and dtype whose [i, j] entry is the
    cost of the cheapest warp path from [0, 0] to [i, j], both endpoints
    included
    '''
    n_rows, n_cols = distances.shape
    DTW = np.empty_like(distances)
    if n_rows == 0 or n_cols == 0:
        return DTW  # nothing to accumulate
    # Every warp path starts at [0, 0], so that cell costs its own distance.
    DTW[0, 0] = distances[0, 0]
    # Along the first column / first row there is only one way to arrive:
    # straight down / straight across from the previous border cell.
    for i in range(1, n_rows):
        DTW[i, 0] = distances[i, 0] + DTW[i - 1, 0]
    for j in range(1, n_cols):
        DTW[0, j] = distances[0, j] + DTW[0, j - 1]
    for i in range(1, n_rows):
        for j in range(1, n_cols):
            DTW[i, j] = distances[i, j] + min(DTW[i - 1, j],      # insertion
                                              DTW[i, j - 1],      # deletion
                                              DTW[i - 1, j - 1])  # match
    return DTW


@njit
def backtrack(DTW):
    '''compute DTW backtrace

    DTW: a matrix of cumulative DTW paths
    returns (p, q): x and y index lists of the optimal DTW path, ordered
    from [0, 0] to the bottom-right corner of DTW
    '''
    i, j = DTW.shape[0] - 1, DTW.shape[1] - 1
    p, q = [i], [j]
    # Keep walking until BOTH indices reach 0, so the path always starts at
    # [0, 0] even when it runs along the first row or first column.
    while i > 0 or j > 0:
        if i <= 0:
            # on the first row: the only remaining move is to the left
            j -= 1
        elif j <= 0:
            # on the first column: the only remaining move is upwards
            i -= 1
        else:
            v0 = DTW[i - 1, j - 1]
            v1 = DTW[i, j - 1]
            v2 = DTW[i - 1, j]
            if v0 <= v1 and v0 <= v2:  # v0 argmin: diagonal step
                i -= 1
                j -= 1
            elif v1 <= v0 and v1 <= v2:  # v1 argmin: step left
                j -= 1
            else:  # v2 argmin: step up
                i -= 1
        p.append(i)
        q.append(j)
    # the path was collected end-to-start; flip it into forward order
    p.reverse()
    q.reverse()
    return p, q
@njit
def dtw(a, b):
    '''Align two matrices `a` and `b` with dynamic time warping.

    The first dimension of each input is time; the second dimensions must
    have equal size.

    returns:
    trace_x, trace_y -- the warp path as two lists of indices, suitable for
    use in an interpolation function such as numpy.interp
    to warp values from a to b, use: numpy.interp(warpable_values, trace_x, trace_y)
    to warp values from b to a, use: numpy.interp(warpable_values, trace_y, trace_x)
    '''
    pairwise = cdist_jit(a, b)
    accumulated = dtw_distance(pairwise)
    return backtrack(accumulated)
def build_dtw_mse(shape):
    '''Build an njit-compiled DTW-based distance for flattened samples.

    shape: the shape every flattened sample is reshaped to before alignment

    usage: first build the metric with
    `dtw_metric = build_dtw_mse(x[0].shape)`, then call
    umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))

    returns: a function dtw_mse(a_flat, b_flat) that reshapes both inputs,
    aligns them with `dtw`, and returns the mean squared difference between
    the two samples gathered along the warp path
    '''
    @njit
    def dtw_mse(a_flat, b_flat):
        sample_a = a_flat.reshape(*shape)
        sample_b = b_flat.reshape(*shape)
        path_a, path_b = dtw(sample_a, sample_b)
        # the traces come back as lists; turn them into index arrays so
        # they can gather the matched rows (another way might be faster)
        warped_a = sample_a[np.array(path_a)]
        warped_b = sample_b[np.array(path_b)]
        return np.square(warped_a - warped_b).mean()
    return dtw_mse
@marcdemers thanks! i changed the docstring.
Thanks for your hard work. I have three questions:
(1) Have you tested your method? (just to make sure it calculated DTW correctly)
(2) You mentioned:
"First build a dtw with `dtw_metric = build_dtw_mse(x[0].shape)"
and then
"umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))"
I cannot understand what x is. Is it the same as my_data (which is a samples-by-features numpy array)? We usually use:
model = umap.UMAP(metric=dtw_metric)
model.fit_transform(my_data)
but you use "x.reshape(len(x),-1)" instead. Can you elaborate a little bit so that your code is easier for me to understand? (I am new to programming)
(3) In one of the issues mentioned in UMAP:
https://github.com/lmcinnes/umap/issues/348
it says we can use metric='precomputed' and then feed our model with a distance_matrix. So, we can calculate the dtw_similarity_matrix of our data (any way we would like) and then use:
umap.UMAP(metric='precomputed').fit_transform(dtw_similarity_matrix)
Have you seen this before? I couldn't find the 'precomputed' as an eligible input for the argument "metric".
Thanks,
Nima
Note that this code was developed for dynamic time warping of spectrograms and not for pure time series.
More specifically, the code quietly assumes len(a.shape)==2.
If someone is interested in a simple implementation of DTW for time series have a look at this repo. It is not optimized for parallel computing and a fairly old implementation.
And as a slightly related bonus, here is an implementation of DTW for multivariate time series. I haven't tested it yet, but the description states support for numba:
You should rename the function dtw, or change your example in build_dtw_mse's docstring, as you get a RecursionError when declaring
dtw = build_dtw_mse(x[0].shape). The variable you're assigning to (dtw) is the same as the function's name.
Great work by the way!