Skip to content

Instantly share code, notes, and snippets.

@kylemcdonald
Last active April 28, 2021 16:28
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save kylemcdonald/76b6f18fb4026e01196282b59bd31e7e to your computer and use it in GitHub Desktop.
Save kylemcdonald/76b6f18fb4026e01196282b59bd31e7e to your computer and use it in GitHub Desktop.
DTW MSE numba function for use with UMAP.
# based on https://github.com/kylerbrown/ezdtw
# with modifications to be fully njit-able
import numpy as np
from numba import njit
@njit
def sqeuclidean(a, b):
return np.sum((a - b)**2)
@njit
def cdist_jit(a, b):
na = a.shape[0]
nb = b.shape[0]
m = a.shape[1]
distances = np.empty((na, nb), dtype=a.dtype)
for i in range(na):
for j in range(nb):
distances[i, j] = sqeuclidean(a[i], b[j])
return np.sqrt(distances)
@njit
def dtw_distance(distances):
'''calculate minimum cumulative distance'''
DTW = np.empty_like(distances)
DTW[:, 0] = np.inf
DTW[0, :] = np.inf
DTW[0, 0] = 0
for i in range(1, DTW.shape[0]):
for j in range(1, DTW.shape[1]):
DTW[i, j] = distances[i, j] + min(DTW[i-1, j], # insertion
DTW[i, j-1], # deletion
DTW[i-1, j-1] # match
)
return DTW
@njit
def backtrack(DTW):
'''compute DTW backtrace
DTW: a matrix of cumulative DTW paths
returns (p, q): x and y index lists of the optimal DTW path'''
i, j = DTW.shape[0] - 1, DTW.shape[1] - 1
p, q = [i], [j]
while i > 0 and j > 0:
v0 = DTW[i - 1, j - 1]
v1 = DTW[i, j - 1]
v2 = DTW[i - 1, j]
if v0 <= v1 and v0 <= v2: # v0 argmin
i -= 1
j -= 1
elif v1 <= v0 and v1 <= v2: # v1 argmin
j -= 1
else: # v2 argmin
i -= 1
p.append(i)
q.append(j)
p.reverse()
q.reverse()
return p, q
@njit
def dtw(a, b):
'''perform dynamic time warping on two matricies a and b
first dimension must be time, second dimension shapes must be equal
returns:
trace_x, trace_y -- the warp path as two lists of indicies. Suitable for use in
an iterpolation function such as numpy.interp
to warp values from a to b, use: numpy.interp(warpable_values, trace_x, trace_y)
to warp values from b to a, use: numpy.interp(warpable_values, trace_y, trace_x)
'''
distance = cdist_jit(a, b)
cum_min_dist = dtw_distance(distance)
trace_x, trace_y = backtrack(cum_min_dist)
return trace_x, trace_y
def build_dtw_mse(shape):
'''
First build a dtw with `dtw_metric = build_dtw_mse(x[0].shape),
then umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))
'''
@njit
def dtw_mse(a_flat, b_flat):
a = a_flat.reshape(*shape)
b = b_flat.reshape(*shape)
trace0, trace1 = dtw(a, b)
# using np.array is easy, but another way might be faster
aw = a[np.array(trace0)]
bw = b[np.array(trace1)]
dist = np.square(aw - bw).mean()
return dist
return dtw_mse
@marcdemers
Copy link

You should rename the function dtw, or change your example in build_dtw_mse's docstring, as you get a RecursionError when declaring
dtw = build_dtw_mse(x[0].shape). The variable you're assigning to (dtw) is the same as the function's name.
Great work by the way!

@kylemcdonald
Copy link
Author

@marcdemers thanks! i changed the docstring.

@NimaSarajpoor
Copy link

Thanks for your hard work. I have three questions:

(1) Have you tested your method? (just to make sure it calculated DTW correctly)

(2) You mentioned:
"First build a dtw with `dtw_metric = build_dtw_mse(x[0].shape)"
and then
"umap.UMAP(metric=dtw_metric).fit_transform(x.reshape(len(x), -1))"

I cannot understand what is x? is it the same as my_data (which is samples-by-features numpy array)? we usually use:
model = umap.UMAP(metric=dtw_metric)
model.fit_transform(my_data)
but you use "x.reshape(len(x),-1)" instead. Can you elaborate a liitle bit so you make my life a little bit easier to understand your code. (I am new to programming)

(3) In one of the issues mentioned in UMAP:
https://github.com/lmcinnes/umap/issues/348

it says we can use metric='precomputed' and then feed our model with distance_matric. So, we can calculate dtw_similarity_matrix of our data (anyway we would like) and then use:
umap.UMAP(metric='precomputed').fit_transform(dtw_similarity_matrix)

Have you seen this before? I couldn't find the 'precomputed' as an eligible input for the argument "metric".

Thanks,
Nima

@msseibel
Copy link

Note that this code was developed for dynamic time warping of spectrograms and not for pure time series.

More specifically, the code quietly assumes len(a.shape)==2.

If someone is interested in a simple implementation of DTW for time series have a look at this repo. It is not optimized for parallel computing and a fairly old implementation.

And as a slightly related bonus, here is a implementation for DTW of multivariate time series. I haven't tested it yet, but the description states support for numba:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment