/*
 * Implementations and tests for DTW (Dynamic Time Warping).
 *
 * Originally shared as a GitHub Gist by @papainge, created March 17, 2017 17:36
 * (gist id 436e827e319f68505ce2ab37d21be080).
 */
import org.junit.Test;
import java.util.Arrays;
import java.util.Random;
import static org.junit.Assert.assertEquals;
/**
 * Dynamic Time Warping (DTW) implementations and their tests.
 *
 * <p>{@link #computeMine} is a straightforward full-matrix reference implementation and
 * {@link #compute} is a memory-efficient variant that keeps only two columns of the cost
 * matrix. Both use the absolute difference of values as the local cost and the classic
 * match/insertion/deletion steps, so they must return the same distance for the same input.
 */
public class DtwTest {

    /**
     * Computes the DTW distance between two series using the full (len1+1) x (len2+1) cost matrix.
     *
     * @param values1 first series
     * @param values2 second series
     * @param threshold early-abandon threshold; {@code 0.0} disables it
     * @return the DTW distance, or {@code -1.0} if either series is empty or the distance is
     *     greater than a non-zero {@code threshold}
     */
    public double computeMine(double[] values1, double[] values2, double threshold) {
        //
        // Make sure values1 is the shorter array (DTW is symmetric, this only shortens the rows).
        //
        if (values1.length > values2.length) {
            double[] tmp = values1;
            values1 = values2;
            values2 = tmp;
        }
        int len1 = values1.length;
        int len2 = values2.length;
        if (len1 == 0) {
            return -1.0D;
        }
        boolean useThreshold = 0.0D != threshold;
        //
        // dtw[i][j] is the cost of the best warping path aligning values1[0..i) with values2[0..j).
        // Row 0 and column 0 are a sentinel border: only the origin is reachable.
        //
        double[][] dtw = new double[len1 + 1][len2 + 1];
        for (double[] row : dtw) {
            Arrays.fill(row, Double.POSITIVE_INFINITY);
        }
        dtw[0][0] = 0.0D;
        for (int i = 1; i <= len1; i++) {
            boolean belowThreshold = false;
            for (int j = 1; j <= len2; j++) {
                // DTW simply considers the delta in values, not the delta in indices.
                double d = Math.abs(values1[i - 1] - values2[j - 1]);
                double best = Math.min(dtw[i - 1][j - 1], Math.min(dtw[i - 1][j], dtw[i][j - 1]));
                dtw[i][j] = d + best;
                if (!useThreshold || dtw[i][j] <= threshold) {
                    belowThreshold = true;
                }
            }
            // Every warping path crosses every row and its cost never decreases (d >= 0),
            // so if the whole row is above the threshold the final distance is too.
            if (!belowThreshold) {
                return -1.0D;
            }
        }
        double dist = dtw[len1][len2];
        // Report -1 whenever the threshold is exceeded, not only when the early exit fired.
        return useThreshold && dist > threshold ? -1.0D : dist;
    }

    /**
     * Computes the DTW distance between {@code values1[offset1 .. offset1+len1)} and
     * {@code values2[offset2 .. offset2+len2)}, keeping only two columns of the cost matrix
     * (memory proportional to the shorter slice).
     *
     * @param values1 array holding the first series
     * @param offset1 start index of the first series in {@code values1}
     * @param len1 length of the first series
     * @param values2 array holding the second series
     * @param offset2 start index of the second series in {@code values2}
     * @param len2 length of the second series
     * @param threshold early-abandon threshold; {@code 0.0} disables it
     * @return the DTW distance, or {@code -1.0} if either series is empty or the distance is
     *     greater than a non-zero {@code threshold}
     */
    public double compute(double[] values1, int offset1, int len1, double[] values2, int offset2, int len2, double threshold) {
        //
        // Make sure values1 is the shorter series so the two cost columns are as small as possible.
        //
        if (len1 > len2) {
            double[] tmp = values1;
            values1 = values2;
            values2 = tmp;
            int tmpint = offset1;
            offset1 = offset2;
            offset2 = tmpint;
            tmpint = len1;
            len1 = len2;
            len2 = tmpint;
        }
        if (len1 == 0) {
            return -1.0D;
        }
        boolean useThreshold = 0.0D != threshold;
        //
        // Column i of the cost matrix corresponds to values2[offset2 + i]; cell j to values1[offset1 + j].
        // prev holds column i-1, cur is column i being filled.
        //
        double[] prev = new double[len1];
        double[] cur = new double[len1];
        for (int i = 0; i < len2; i++) {
            double v2 = values2[offset2 + i];
            boolean belowThreshold = false;
            for (int j = 0; j < len1; j++) {
                // DTW simply considers the delta in values, not the delta in indices.
                double d = Math.abs(values1[offset1 + j] - v2);
                double best;
                if (i == 0) {
                    // First column: the path starts at the origin and can only advance along values1.
                    best = j == 0 ? 0.0D : cur[j - 1];
                } else if (j == 0) {
                    // First cell of a later column: reachable only from the previous column.
                    best = prev[0];
                } else {
                    best = Math.min(prev[j - 1], Math.min(prev[j], cur[j - 1]));
                }
                cur[j] = d + best;
                if (!useThreshold || cur[j] <= threshold) {
                    belowThreshold = true;
                }
            }
            // Every warping path crosses every column and its cost never decreases (d >= 0),
            // so if the whole column is above the threshold the final distance is too.
            if (!belowThreshold) {
                return -1.0D;
            }
            // The column just filled becomes the previous one.
            double[] tmp = prev;
            prev = cur;
            cur = tmp;
        }
        // The distance is the cell aligning both complete series, not the column minimum.
        double dist = prev[len1 - 1];
        return useThreshold && dist > threshold ? -1.0D : dist;
    }

    /** Builds a series of {@code length} standard Gaussian samples drawn from {@code rand}. */
    private static double[] randomSeries(Random rand, int length) {
        double[] series = new double[length];
        for (int i = 0; i < length; i++) {
            series[i] = rand.nextGaussian();
        }
        return series;
    }

    @Test
    public void testComputeSameSeries() {
        Random rand = new Random(1L);
        double[] a = randomSeries(rand, 50);
        // assertEquals takes (expected, actual, delta).
        assertEquals(0.0, computeMine(a, a, 0.0), 1e-5);
        assertEquals(0.0, compute(a, 0, a.length, a, 0, a.length, 0.0), 1e-5);
    }

    @Test
    public void testComputeKnownValues() {
        double[] x = new double[]{0.0, 1.0, 2.0};
        double[] y = new double[]{0.0, 2.0};
        assertEquals(1.0, computeMine(x, y, 0.0), 1e-9);
        assertEquals(1.0, compute(x, 0, x.length, y, 0, y.length, 0.0), 1e-9);
        assertEquals(1.0, compute(y, 0, y.length, x, 0, x.length, 0.0), 1e-9);
        // The leading 5.0 must be matched against some 0.0, so the distance is 5 in both directions.
        double[] p = new double[]{5.0, 0.0};
        double[] q = new double[]{0.0, 0.0};
        assertEquals(5.0, computeMine(p, q, 0.0), 1e-9);
        assertEquals(5.0, compute(p, 0, p.length, q, 0, q.length, 0.0), 1e-9);
        assertEquals(5.0, compute(q, 0, q.length, p, 0, p.length, 0.0), 1e-9);
        // Offsets select a slice of a larger array.
        double[] padded = new double[]{9.0, 0.0, 1.0, 2.0, 9.0};
        assertEquals(1.0, compute(padded, 1, 3, y, 0, y.length, 0.0), 1e-9);
        // Empty series.
        assertEquals(-1.0, computeMine(new double[0], y, 0.0), 1e-9);
        assertEquals(-1.0, compute(x, 0, 0, y, 0, y.length, 0.0), 1e-9);
    }

    @Test
    public void testComputeThreshold() {
        double[] x = new double[]{0.0, 1.0, 2.0};
        double[] y = new double[]{0.0, 2.0};
        assertEquals(-1.0, computeMine(x, y, 0.5), 1e-9);
        assertEquals(-1.0, compute(x, 0, x.length, y, 0, y.length, 0.5), 1e-9);
        assertEquals(1.0, computeMine(x, y, 2.0), 1e-9);
        assertEquals(1.0, compute(x, 0, x.length, y, 0, y.length, 2.0), 1e-9);
    }

    @Test
    public void testComputeMatchesFullMatrix() {
        Random rand = new Random(7L);
        for (int trial = 0; trial < 25; trial++) {
            double[] x = randomSeries(rand, 1 + rand.nextInt(20));
            double[] y = randomSeries(rand, 1 + rand.nextInt(20));
            double expected = computeMine(x, y, 0.0);
            assertEquals(expected, computeMine(y, x, 0.0), 1e-9);
            assertEquals(expected, compute(x, 0, x.length, y, 0, y.length, 0.0), 1e-9);
            assertEquals(expected, compute(y, 0, y.length, x, 0, x.length, 0.0), 1e-9);
        }
    }

    @Test
    public void testComputeSymmetric() {
        double[] global = new double[]{0.008798913914705466, 0.008052009527598209, 0.00836504210740137, 0.0065654853519703585, 0.012582283303258168, 0.021248433719598687, 0.0051440693332960065, 0.010232085648249326, 0.0053947364301326125, 0, 0.47224753354072735, 0.9486669017930869, 0.9129323810856173, 0.9915809482571225, 0.9250711869932177, 0.9457249751954055, 0.9606886284144219, 0.9627061725417592, 0.9472936670574867, 0.9525507490007642, 0.9846334675674775, 0.8489200374448909, 0.9673176677158484, 0.9423086535125491, 1, 0.9480755199197319, 0.9590455052867518, 0.918655451850685, 0.9390281455702689, 0.9168174809823195, 0.8976719489194196, 0.9837622226496971, 0.8721874653794999, 0.7781882249351341, 0.8678270455332947, 0.8479609931352973, 0.848813502573319, 0.8637748937327843, 0.8712330025655161, 0.818079460940376, 0.9510091151350031, 0.8468076396974356, 0.8437379026547026, 0.8524615912092984, 0.7977192219626514, 0.7354493449722723, 0.8267995857443761, 0.7808016218307058, 0.82027655468507, 0.8463183479729905, 0.8201550224531412, 0.825225738275906, 0.7374624832740365, 0.4049362757456661, 0.39708769426661755, 0.4108926735152398, 0.4107776697347616, 0.43385394640480834, 0.418079680785449, 0.4328941214655191, 0.13958909603219796, 0.02187134033769401, 0.023938458857408567, 0.018985222365855484, 0.027484655049983795, 0.022511085112920767, 0.03616202243281704, 0.03005964384741452, 0.03317153126580197, 0.0313716325877124, 0.03482260006389524, 0.033120022702661456};
        double[] po10 = new double[]{0.30466309264515706, 0.1944024813971762, 0.21484484780279942, 0.21673313469222616, 0.09822491880730029, 0.08194634559820839, 0.19133247402896017, 0.04199095111960927, 0.09600572672813405, 0.031879824091934295, 0.5088930228684865, 0.9958187536587756, 1, 0.8963999941739437, 0.9619069193648051, 0.9386548459139162, 0.9301018750011757, 0.9167065774543025, 0.9673418517000545, 0.8955885743294247, 0.9772091860728065, 0.9189796959923074, 0.9821856851177897, 0.8489456315479653, 0.9593908820522833, 0.9215583908994651, 0.8948568610229924, 0.876953216082639, 0.9235596059381347, 0.8818457241000301, 0.8993712782242151, 0.9008561163418505, 0.7325343675981016, 0.6454674022877228, 0.7042270646344815, 0.7092341154727027, 0.8115084980870579, 0.74195243298993, 0.7384736914813406, 0.7816201912423166, 0.9122212019024518, 0.726844021832641, 0.7816424444997712, 0.776822890183573, 0.7463782047101023, 0.6420152504218215, 0.6922547012814555, 0.6965342750169002, 0.8104808310007804, 0.8417053051876552, 0.7727738622093474, 0.7947177772067585, 0.6393263578249365, 0.33730540206318116, 0.44018034226684766, 0.3426637256304818, 0.33746124879345696, 0.3987818755155467, 0.3667894617853437, 0.37800866647487835, 0.201654967958607, 0.04923293445595179, 0.04569114882381592, 0.04275887160539717, 0.04266842555690243, 0, 0.09998288534008508, 0.0261848221405116, 0.023015719666605452, 0.0409478943024898, 0.05516380985154353, 0.033736027210345014};
        double[] eth161_1_2 = new double[]{0.7442107467644723, 0.14014923206672353, 0.7893553456424499, 0.6975018906145588, 0.34464625347350886, 0.8582947573323689, 0.8073829604650142, 0.5780877604310494, 0.14641671647612683, 0.17084708323553713, 0.269530171749867, 0.5987451904318051, 0.5801880873892526, 0.13001646477507872, 0.7619322536708423, 0.8754148860638296, 0.2368184965008914, 0.08574298079046734, 0.12840706783014486, 0.7455920238842253, 1, 0.0711243082054153, 0.2428945533006623, 0.94792791731897, 0.6147675031751185, 0.10335548591881703, 0.8036568646535238, 0.758533974755294, 0.5331107124109977, 0.8999902639025142, 0.7936632110914822, 0.4205149151655969, 0.8721986086169509, 0.5474846864645251, 0.2605555766040642, 0.4355202422818324, 0.14926522894547817, 0.7898776837016848, 0.6878753395550866, 0.08982812841066397, 0.3219067127724229, 0.7945019669085895, 0.849655453868014, 0.34890292078980983, 0.7467924441381396, 0.6890759906532694, 0.38020163193624257, 0.8047286240948668, 0.7677669700018064, 0.1467041680875612, 0.7889614796325043, 0.67276056941163, 0.5058161309627242, 0.5879249749202129, 0.9096771146109283, 0.6584000384297996, 0.7845756982312924, 0.7846837898576915, 0.4678261149484802, 0.7890903532546358, 0.730851544878136, 0.09072518563101385, 0.8780652043009288, 0.601567438354048, 0, 0.644565463269724, 0.9836152593505539, 0.8547914575001178, 0.3712852095460953, 0.7918123979319661, 0.5663240098507405, 0.04179745196789641};
        double[] eth1_8 = new double[]{0.0053324347902166174, 0.0035665745179808248, 0.006497639423487071, 0.0057654474537357815, 0.02174078470042538, 0.007058746401862991, 0.002426315311281271, 0.008424830813677302, 0.0047777409103846175, 0, 0.4429601536950297, 0.814179077099929, 0.7272503799114347, 0.7909147150487464, 0.8052750943536133, 0.827882408049785, 0.8549813630263762, 0.8555314248985835, 0.8105389259498434, 0.8897782584407541, 0.9909421326416074, 0.9405497718486087, 0.9200446082172594, 0.7635146368672826, 0.948959576944235, 1, 0.8728439254081488, 0.7844944940407534, 0.9207279577357137, 0.9305722259503477, 0.7750701585473521, 0.9243802714464977, 0.8203935301336704, 0.6192934949010848, 0.790743831767644, 0.7666322241119, 0.7331052656724468, 0.777814338379208, 0.8307605087593654, 0.864595422880208, 0.8869865076412065, 0.7369835602584996, 0.8106869983116548, 0.8308423068760392, 0.8223124499601715, 0.7210886993859145, 0.7549518725854494, 0.7499378602974681, 0.873017456235872, 0.7412882566820548, 0.8456936348898456, 0.8166915419935223, 0.6585020073838521, 0.44839171842860526, 0.46016170247160254, 0.4740960953041661, 0.4709345487346007, 0.4335646214964055, 0.44935269200255545, 0.46770311121903557, 0.19056131771694082, 0.02015374675601678, 0.01766385305312402, 0.013491316146783205, 0.023905481233240444, 0.022855248993903267, 0.031114368512608473, 0.027454970453486056, 0.027435593015017495, 0.027172636281775273, 0.03505439088581834, 0.03233120417538101};
        double[] eth159_1_7 = new double[]{0.17016448546315915, 0.12079961544505, 0.13405675940348458, 0.036411601892918824, 0.09838532527278734, 0.10869454966959739, 0.08953006117432556, 0.1052737866845647, 0.195471743348033, 0.16143906224150467, 0.07090429186401533, 0.08089157235931134, 0.08713357074913494, 0.0835591253443369, 0.4580032674543874, 0.4118916881195556, 0.4737552585866626, 0.4140111763670664, 0.5929302847116354, 0.7712664797916133, 0.713054818976024, 0.863583800095225, 0.8994371609822932, 0.8295884181273315, 0.37915322066639773, 0.06158370231997737, 0.08205150080643812, 0.1301185020217338, 0.09513556513158561, 0, 0.12070910896054486, 0.4159065790322436, 1, 0.731790118162259, 0.850774993444614, 0.6104404857976053, 0.3432107655116547, 0.03552776183611116, 0.15640072948475728, 0.28673908043833385, 0.49853758716856467, 0.03404958617270257, 0.0762141856100671, 0.021294485297431345, 0.11772335029502662, 0.1219532304522296, 0.08734544480946556, 0.028896157744287644, 0.10969635136752222, 0.09158749495288243, 0.19619263850410018, 0.24800961536903943, 0.13057373427055244, 0.13897667362543434, 0.20881374492365165, 0.14273832083182228, 0.15294736090526093, 0.13525369623872252, 0.23268199539426254, 0.2966893940813079, 0.43039986961085086, 0.06810610882218915, 0.08219334552630271, 0.16431587253214588, 0.1459307615840305, 0.00974907274465963, 0.11432264909614179, 0.05211179789615416, 0.1950343299527493, 0.2239922098796779, 0.05440270434538704, 0.8991471165668278};
        System.out.println("Should be correlated");
        System.out.println(computeMine(po10, global, 0.0) + " " + compute(po10, 0, po10.length, global, 0, global.length, 0.0));
        System.out.println(computeMine(eth1_8, global, 0.0) + " " + compute(global, 0, global.length, eth1_8, 0, eth1_8.length, 0.0));
        System.out.println("Should not be correlated");
        System.out.println(computeMine(global, eth161_1_2, 0.0) + " " + compute(eth161_1_2, 0, eth161_1_2.length, global, 0, global.length, 0.0));
        System.out.println(computeMine(global, eth159_1_7, 0.0) + " " + compute(global, 0, global.length, eth159_1_7, 0, eth159_1_7.length, 0.0));
        assertEquals(computeMine(global, eth161_1_2, 0.0), computeMine(eth161_1_2, global, 0.0), 1e-5);
        assertEquals(compute(global, 0, global.length, eth161_1_2, 0, eth161_1_2.length, 0.0), compute(eth161_1_2, 0, eth161_1_2.length, global, 0, global.length, 0.0), 1e-5);
        // The two-column implementation must agree with the full-matrix reference.
        assertEquals(computeMine(po10, global, 0.0), compute(po10, 0, po10.length, global, 0, global.length, 0.0), 1e-9);
        assertEquals(computeMine(eth1_8, global, 0.0), compute(global, 0, global.length, eth1_8, 0, eth1_8.length, 0.0), 1e-9);
        assertEquals(computeMine(global, eth161_1_2, 0.0), compute(eth161_1_2, 0, eth161_1_2.length, global, 0, global.length, 0.0), 1e-9);
        assertEquals(computeMine(global, eth159_1_7, 0.0), compute(global, 0, global.length, eth159_1_7, 0, eth159_1_7.length, 0.0), 1e-9);
    }
}
// (End of the original GitHub Gist page: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")