
Last active April 28, 2021 16:28
kylemcdonald/76b6f18fb4026e01196282b59bd31e7e
DTW MSE numba function for use with UMAP.
# based on
# with modifications to be fully njit-able
import numpy as np
from numba import njit


@njit
def sqeuclidean(a, b):
    return np.sum((a - b)**2)


@njit
def cdist_jit(a, b):
    '''pairwise Euclidean distances between the rows of a and the rows of b'''
    na = a.shape[0]
    nb = b.shape[0]
    distances = np.empty((na, nb), dtype=a.dtype)
    for i in range(na):
        for j in range(nb):
            distances[i, j] = sqeuclidean(a[i], b[j])
    return np.sqrt(distances)
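
@njit
def dtw_distance(distances):
    '''calculate minimum cumulative distance'''
    DTW = np.empty_like(distances)
    # the first row and column can only be reached by moving straight
    # right or straight down from the origin, so accumulate them directly
    DTW[0, 0] = distances[0, 0]
    for i in range(1, DTW.shape[0]):
        DTW[i, 0] = DTW[i - 1, 0] + distances[i, 0]
    for j in range(1, DTW.shape[1]):
        DTW[0, j] = DTW[0, j - 1] + distances[0, j]
    for i in range(1, DTW.shape[0]):
        for j in range(1, DTW.shape[1]):
            DTW[i, j] = distances[i, j] + min(DTW[i-1, j],    # insertion
                                              DTW[i, j-1],    # deletion
                                              DTW[i-1, j-1])  # match
    return DTW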
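
@njit
def backtrack(DTW):
    '''compute DTW backtrace
    DTW: a matrix of cumulative DTW paths
    returns (p, q): x and y index lists of the optimal DTW path'''
    i, j = DTW.shape[0] - 1, DTW.shape[1] - 1
    p, q = [i], [j]
    while i > 0 or j > 0:
        if i == 0:  # first row: can only move left
            j -= 1
        elif j == 0:  # first column: can only move up
            i -= 1
        else:
            v0 = DTW[i - 1, j - 1]
            v1 = DTW[i, j - 1]
            v2 = DTW[i - 1, j]
            if v0 <= v1 and v0 <= v2:  # v0 argmin
                i -= 1
                j -= 1
            elif v1 <= v0 and v1 <= v2:  # v1 argmin
                j -= 1
            else:  # v2 argmin
                i -= 1
        p.append(i)
        q.append(j)
    # the path was collected from the end, so reverse it to start at (0, 0)
    p.reverse()
    q.reverse()
    return p, q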

@njit
def dtw(a, b):
    '''perform dynamic time warping on two matrices a and b
    first dimension must be time, second dimension shapes must be equal
    returns trace_x, trace_y -- the warp path as two lists of indices, suitable
    for use in an interpolation function such as numpy.interp
    to warp values from a to b, use: numpy.interp(warpable_values, trace_x, trace_y)
    to warp values from b to a, use: numpy.interp(warpable_values, trace_y, trace_x)'''
    distance = cdist_jit(a, b)
    cum_min_dist = dtw_distance(distance)
    trace_x, trace_y = backtrack(cum_min_dist)
    return trace_x, trace_y

def build_dtw_mse(shape):
    '''First build a dtw metric with `dtw_metric = build_dtw_mse(x[0].shape)`,
    then umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))'''
    # a tuple can be captured by the jitted closure as a compile-time constant
    shape = tuple(shape)

    @njit
    def dtw_mse(a_flat, b_flat):
        a = a_flat.reshape(shape)
        b = b_flat.reshape(shape)
        trace0, trace1 = dtw(a, b)
        # using np.array is easy, but another way might be faster
        aw = a[np.array(trace0)]
        bw = b[np.array(trace1)]
        # mean squared error between the frames paired by the warp path
        dist = np.square(aw - bw).mean()
        return dist
    return dtw_mse
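One way to test the jitted functions is to compare them with a slow but easy-to-read reference on small inputs. The sketch below is a hypothetical plain-Python helper (`dtw_reference` is not part of the gist) using the standard padded DTW recurrence with Euclidean frame distances and the same tie-breaking order as `backtrack()`. Its path can be compared with `list(zip(*dtw(a, b)))`; note that its cost is the summed Euclidean distance along the path, not the MSE that `dtw_mse` returns.

```python
import math


def dtw_reference(a, b):
    """Plain-Python DTW for cross-checking.

    a, b: lists of frames, each frame a tuple of numbers of the same length.
    Returns (total_cost, path) where path is a list of (i, j) index pairs.
    """
    na, nb = len(a), len(b)
    # Cumulative cost matrix padded with one extra row/column of infinity,
    # so D[i][j] is the best cost of aligning a[:i] with b[:j].
    D = [[math.inf] * (nb + 1) for _ in range(na + 1)]
    D[0][0] = 0.0
    for i in range(1, na + 1):
        for j in range(1, nb + 1):
            cost = math.dist(a[i - 1], b[j - 1])  # Euclidean frame distance
            D[i][j] = cost + min(D[i - 1][j], D[i][j - 1], D[i - 1][j - 1])

    # Walk back from the last cell; min() keeps the first candidate on ties,
    # so the diagonal (match) step is preferred, as in backtrack().
    i, j = na, nb
    path = [(i - 1, j - 1)]
    while i > 1 or j > 1:
        candidates = [
            (D[i - 1][j - 1], i - 1, j - 1),  # match
            (D[i][j - 1], i, j - 1),          # deletion
            (D[i - 1][j], i - 1, j),          # insertion
        ]
        _, i, j = min(candidates, key=lambda c: c[0])
        path.append((i - 1, j - 1))
    path.reverse()
    return D[na][nb], path


cost, path = dtw_reference([(0,), (1,), (2,)], [(0,), (0,), (1,), (2,)])
print(cost, path)  # 0.0 [(0, 0), (0, 1), (1, 2), (2, 3)]
```

In this example the repeated 0 in the second sequence is absorbed by matching the first frame of `a` twice, so the total cost is zero.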

You should rename the function dtw or change the example in build_dtw_mse's docstring: declaring
dtw = build_dtw_mse(x[0].shape) raises a RecursionError, because the variable you're assigning to (dtw) has the same name as the function.
Great work by the way!
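The recursion comes from ordinary Python name lookup: the inner metric looks up the global name `dtw` when it runs, not when it is defined, so rebinding `dtw` to the metric makes the metric call itself. A stripped-down sketch in plain Python without numba (`build_metric` is a hypothetical stand-in for `build_dtw_mse`, and `dtw` here is a trivial stand-in); the same name collision applies to the jitted version.

```python
def dtw(a, b):
    """Stand-in for the gist's dtw(); returns a trivial warp path."""
    return [0], [0]


def build_metric():
    """Hypothetical stand-in for build_dtw_mse()."""
    def metric(a, b):
        # `dtw` is resolved in the module globals each time metric() runs.
        trace0, trace1 = dtw(a, b)
        return float(len(trace0))
    return metric


dtw_metric = build_metric()    # fine: `dtw` still names the function above
ok = dtw_metric([1.0], [1.0])  # 1.0

dtw = build_metric()           # rebinds `dtw`: metric() now calls itself
try:
    dtw([1.0], [1.0])
    failed = False
except RecursionError:
    failed = True
print(ok, failed)  # 1.0 True
```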


@marcdemers thanks! I changed the docstring.


Thanks for your hard work. I have three questions:

(1) Have you tested your method? (just to make sure it calculated DTW correctly)

(2) You mentioned:
"First build a dtw with `dtw_metric = build_dtw_mse(x[0].shape)`"
and then
"umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))"

I don't understand what x is. Is it the same as my_data (a samples-by-features numpy array)? We usually use:
model = umap.UMAP(metric=dtw_metric)
but you use "x.reshape(len(x), -1)" instead. Could you elaborate a little to make your code easier for me to understand? (I am new to programming.)

(3) In one of the UMAP issues:

it says we can use metric='precomputed' and then feed our model a distance matrix. So we could calculate a dtw_similarity_matrix of our data (any way we like) and then use:

Have you seen this before? I couldn't find 'precomputed' listed as a valid value for the argument "metric".
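Regarding (3): UMAP does accept `metric="precomputed"`, in which case `fit_transform` expects a square distance matrix rather than feature vectors. A minimal sketch of building such a matrix in plain Python, assuming each sample is a list of equal-width frames. The names `dtw_cost` and `pairwise_dtw` are hypothetical, and this uses the standard padded DTW recurrence rather than the gist's jitted code. Unlike the flattened-feature approach, the samples here may have different lengths.

```python
import math


def dtw_cost(a, b):
    """Minimum cumulative Euclidean cost of aligning two sequences of frames."""
    D = [[math.inf] * (len(b) + 1) for _ in range(len(a) + 1)]
    D[0][0] = 0.0
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            cost = math.dist(a[i - 1], b[j - 1])
            D[i][j] = cost + min(D[i - 1][j], D[i][j - 1], D[i - 1][j - 1])
    return D[len(a)][len(b)]


def pairwise_dtw(samples):
    """Square, symmetric distance matrix with zeros on the diagonal."""
    n = len(samples)
    M = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # DTW is symmetric, so fill both halves
            M[i][j] = M[j][i] = dtw_cost(samples[i], samples[j])
    return M


samples = [
    [(0.0,), (1.0,), (2.0,)],
    [(0.0,), (0.0,), (1.0,), (2.0,)],
    [(5.0,), (5.0,)],
]
M = pairwise_dtw(samples)
# With umap-learn and numpy installed, and a realistic number of samples,
# the matrix could then be used as:
# umap.UMAP(metric="precomputed").fit_transform(np.array(M))
```

The trade-off is that the full matrix needs n(n-1)/2 DTW evaluations and O(n²) memory, whereas a custom numba metric lets UMAP evaluate only the pairs its nearest-neighbor search asks for.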



Note that this code was developed for dynamic time warping of spectrograms and not for pure time series.

More specifically, the code quietly assumes len(a.shape)==2.
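For a plain (univariate) time series, one workaround is to add a trailing feature axis of length 1, so each sample has the `(time, features)` layout that `cdist_jit` and `dtw` index into. The same sketch presumably shows what `x` in the docstring is: an array of shape `(n_samples, time, features)`, flattened to one row per sample for UMAP and reshaped back inside the metric. It uses NumPy with made-up random data and does not call the gist's functions.

```python
import numpy as np

# Made-up batch of 10 univariate series, each 50 steps long: shape (10, 50)
series = np.random.default_rng(0).normal(size=(10, 50))

# Add a trailing feature axis so every sample is (time, features) = (50, 1)
x = series[:, :, np.newaxis]

# UMAP expects one flat row per sample; build_dtw_mse(x[0].shape) would
# receive (50, 1) and reshape each row back to that shape inside the metric.
flat = x.reshape(len(x), -1)
restored = flat[3].reshape(x[0].shape)
```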

If someone is interested in a simple implementation of DTW for time series, have a look at this repo. It is not optimized for parallel computing and is a fairly old implementation.

And as a slightly related bonus, here is an implementation of DTW for multivariate time series. I haven't tested it yet, but the description states support for numba:
