Created
March 17, 2017 17:36
-
-
Save papainge/436e827e319f68505ce2ab37d21be080 to your computer and use it in GitHub Desktop.
Implementations and tests for DTW
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.Test; | |
import java.util.Arrays; | |
import java.util.Random; | |
import static org.junit.Assert.assertEquals; | |
public class DtwTest { | |
public double computeMine(double[] values1, double[] values2, double threshold) { | |
// | |
// Make sure value1 is the shortest array | |
// | |
int len1 = values1.length; | |
int len2 = values2.length; | |
if (len1 > len2) { | |
double[] tmp = values1; | |
values1 = values2; | |
values2 = tmp; | |
int tmpint = len1; | |
len1 = len2; | |
len2 = tmpint; | |
} | |
// | |
// Now run DTW. | |
// We allocate two columns so we can run DTW, only allocating | |
// | |
double[][] dtw = new double[len1 + 1][len2 + 1]; | |
for (int i = 0; i < len1 + 1; i++) { | |
Arrays.fill(dtw[i], Double.POSITIVE_INFINITY); | |
} | |
dtw[0][0] = 0.0; | |
int w = len2; | |
boolean belowThreshold = false; | |
for (int i = 1; i < len1 + 1; i++) { | |
int start = Math.max(1, i - w); | |
int end = Math.min(len2 + 1, i + w); | |
belowThreshold = false; | |
for (int j = start; j < end; j++) { | |
// | |
// Compute distance. | |
// DTW simply considers the delta in values, not the delta in indices | |
// | |
double d = Math.abs(values1[i-1] - values2[j-1]); | |
// | |
// Extract surrounding values | |
// | |
double left = dtw[i - 1][j]; | |
double upper = dtw[i][j - 1]; | |
double diag = dtw[i - 1][j - 1]; | |
dtw[i][j] = d + Math.min(left, Math.min(upper, diag)); | |
if (!belowThreshold && (0.0D == threshold || dtw[i][j] <= threshold)) { | |
belowThreshold = true; | |
} | |
} | |
// Exit if no value is below threshold | |
if (!belowThreshold) { | |
break; | |
} | |
} | |
if (!belowThreshold) { | |
return -1.0D; | |
} | |
return dtw[len1][len2]; | |
} | |
public double compute(double[] values1, int offset1, int len1, double[] values2, int offset2, int len2, double threshold) { | |
// | |
// Make sure value1 is the shortest array | |
// | |
if (len1 > len2) { | |
double[] tmp = values1; | |
values1 = values2; | |
values2 = tmp; | |
int tmpint = offset1; | |
offset1 = offset2; | |
offset2 = tmpint; | |
tmpint = len1; | |
len1 = len2; | |
len2 = tmpint; | |
} | |
// | |
// Now run DTW. | |
// We allocate two columns so we can run DTW, only allocating | |
// | |
double[] a = new double[len1]; | |
double[] b = new double[len1]; | |
int w = values2.length; | |
boolean belowThreshold = false; | |
for (int i = 0; i < len2; i++) { | |
int start = Math.max(0, i - w); | |
int end = Math.min(len1 - 1, i + w); | |
belowThreshold = false; | |
for (int j = start; j <= end; j++) { | |
// | |
// Compute distance. | |
// DTW simply considers the delta in values, not the delta in indices | |
// | |
double d = Math.abs(values1[offset1 + j] - values2[offset2 + i]); | |
// | |
// Extract surrounding values | |
// | |
double left = a[j]; | |
double bottom = j > start ? b[j - 1] : Double.POSITIVE_INFINITY; | |
double sw = j > start ? a[j - 1] : Double.POSITIVE_INFINITY; | |
b[j] = d + Math.min(left, Math.min(bottom, sw)); | |
//System.out.println(i + " " + j + " " + d + " " + left + " " + bottom + " " + sw + " >>> " + b[j]); | |
if (!belowThreshold && (0.0D == threshold || b[j] <= threshold)) { | |
belowThreshold = true; | |
} | |
} | |
// Exit if no value is below threshold | |
if (!belowThreshold) { | |
break; | |
} | |
// Shift b into a | |
double[] tmp = a; | |
a = b; | |
b = tmp; | |
} | |
if (!belowThreshold) { | |
return -1.0D; | |
} | |
double dtwDist = Double.POSITIVE_INFINITY; | |
for (int i = 0; i < a.length; i++) { | |
if (a[i] < dtwDist) { | |
dtwDist = a[i]; | |
} | |
} | |
return dtwDist; | |
} | |
@Test | |
public void testComputeSameSeries() { | |
Random rand = new Random(1L); | |
double[] a = new double[50]; | |
for (int i = 0; i < a.length; i++) { | |
a[i] = rand.nextGaussian(); | |
} | |
assertEquals(computeMine(a, a, 0.0), 0.0, 1e-5); | |
assertEquals(compute(a, 0, a.length, a, 0, a.length, 0.0), 0.0, 1e-5); | |
} | |
@Test | |
public void testComputeSymmetric() { | |
double[] global = new double[]{0.008798913914705466, 0.008052009527598209, 0.00836504210740137, 0.0065654853519703585, 0.012582283303258168, 0.021248433719598687, 0.0051440693332960065, 0.010232085648249326, 0.0053947364301326125, 0, 0.47224753354072735, 0.9486669017930869, 0.9129323810856173, 0.9915809482571225, 0.9250711869932177, 0.9457249751954055, 0.9606886284144219, 0.9627061725417592, 0.9472936670574867, 0.9525507490007642, 0.9846334675674775, 0.8489200374448909, 0.9673176677158484, 0.9423086535125491, 1, 0.9480755199197319, 0.9590455052867518, 0.918655451850685, 0.9390281455702689, 0.9168174809823195, 0.8976719489194196, 0.9837622226496971, 0.8721874653794999, 0.7781882249351341, 0.8678270455332947, 0.8479609931352973, 0.848813502573319, 0.8637748937327843, 0.8712330025655161, 0.818079460940376, 0.9510091151350031, 0.8468076396974356, 0.8437379026547026, 0.8524615912092984, 0.7977192219626514, 0.7354493449722723, 0.8267995857443761, 0.7808016218307058, 0.82027655468507, 0.8463183479729905, 0.8201550224531412, 0.825225738275906, 0.7374624832740365, 0.4049362757456661, 0.39708769426661755, 0.4108926735152398, 0.4107776697347616, 0.43385394640480834, 0.418079680785449, 0.4328941214655191, 0.13958909603219796, 0.02187134033769401, 0.023938458857408567, 0.018985222365855484, 0.027484655049983795, 0.022511085112920767, 0.03616202243281704, 0.03005964384741452, 0.03317153126580197, 0.0313716325877124, 0.03482260006389524, 0.033120022702661456}; | |
double[] po10 = new double[]{0.30466309264515706, 0.1944024813971762, 0.21484484780279942, 0.21673313469222616, 0.09822491880730029, 0.08194634559820839, 0.19133247402896017, 0.04199095111960927, 0.09600572672813405, 0.031879824091934295, 0.5088930228684865, 0.9958187536587756, 1, 0.8963999941739437, 0.9619069193648051, 0.9386548459139162, 0.9301018750011757, 0.9167065774543025, 0.9673418517000545, 0.8955885743294247, 0.9772091860728065, 0.9189796959923074, 0.9821856851177897, 0.8489456315479653, 0.9593908820522833, 0.9215583908994651, 0.8948568610229924, 0.876953216082639, 0.9235596059381347, 0.8818457241000301, 0.8993712782242151, 0.9008561163418505, 0.7325343675981016, 0.6454674022877228, 0.7042270646344815, 0.7092341154727027, 0.8115084980870579, 0.74195243298993, 0.7384736914813406, 0.7816201912423166, 0.9122212019024518, 0.726844021832641, 0.7816424444997712, 0.776822890183573, 0.7463782047101023, 0.6420152504218215, 0.6922547012814555, 0.6965342750169002, 0.8104808310007804, 0.8417053051876552, 0.7727738622093474, 0.7947177772067585, 0.6393263578249365, 0.33730540206318116, 0.44018034226684766, 0.3426637256304818, 0.33746124879345696, 0.3987818755155467, 0.3667894617853437, 0.37800866647487835, 0.201654967958607, 0.04923293445595179, 0.04569114882381592, 0.04275887160539717, 0.04266842555690243, 0, 0.09998288534008508, 0.0261848221405116, 0.023015719666605452, 0.0409478943024898, 0.05516380985154353, 0.033736027210345014}; | |
double[] eth161_1_2 = new double[]{0.7442107467644723, 0.14014923206672353, 0.7893553456424499, 0.6975018906145588, 0.34464625347350886, 0.8582947573323689, 0.8073829604650142, 0.5780877604310494, 0.14641671647612683, 0.17084708323553713, 0.269530171749867, 0.5987451904318051, 0.5801880873892526, 0.13001646477507872, 0.7619322536708423, 0.8754148860638296, 0.2368184965008914, 0.08574298079046734, 0.12840706783014486, 0.7455920238842253, 1, 0.0711243082054153, 0.2428945533006623, 0.94792791731897, 0.6147675031751185, 0.10335548591881703, 0.8036568646535238, 0.758533974755294, 0.5331107124109977, 0.8999902639025142, 0.7936632110914822, 0.4205149151655969, 0.8721986086169509, 0.5474846864645251, 0.2605555766040642, 0.4355202422818324, 0.14926522894547817, 0.7898776837016848, 0.6878753395550866, 0.08982812841066397, 0.3219067127724229, 0.7945019669085895, 0.849655453868014, 0.34890292078980983, 0.7467924441381396, 0.6890759906532694, 0.38020163193624257, 0.8047286240948668, 0.7677669700018064, 0.1467041680875612, 0.7889614796325043, 0.67276056941163, 0.5058161309627242, 0.5879249749202129, 0.9096771146109283, 0.6584000384297996, 0.7845756982312924, 0.7846837898576915, 0.4678261149484802, 0.7890903532546358, 0.730851544878136, 0.09072518563101385, 0.8780652043009288, 0.601567438354048, 0, 0.644565463269724, 0.9836152593505539, 0.8547914575001178, 0.3712852095460953, 0.7918123979319661, 0.5663240098507405, 0.04179745196789641}; | |
double[] eth1_8 = new double[]{0.0053324347902166174, 0.0035665745179808248, 0.006497639423487071, 0.0057654474537357815, 0.02174078470042538, 0.007058746401862991, 0.002426315311281271, 0.008424830813677302, 0.0047777409103846175, 0, 0.4429601536950297, 0.814179077099929, 0.7272503799114347, 0.7909147150487464, 0.8052750943536133, 0.827882408049785, 0.8549813630263762, 0.8555314248985835, 0.8105389259498434, 0.8897782584407541, 0.9909421326416074, 0.9405497718486087, 0.9200446082172594, 0.7635146368672826, 0.948959576944235, 1, 0.8728439254081488, 0.7844944940407534, 0.9207279577357137, 0.9305722259503477, 0.7750701585473521, 0.9243802714464977, 0.8203935301336704, 0.6192934949010848, 0.790743831767644, 0.7666322241119, 0.7331052656724468, 0.777814338379208, 0.8307605087593654, 0.864595422880208, 0.8869865076412065, 0.7369835602584996, 0.8106869983116548, 0.8308423068760392, 0.8223124499601715, 0.7210886993859145, 0.7549518725854494, 0.7499378602974681, 0.873017456235872, 0.7412882566820548, 0.8456936348898456, 0.8166915419935223, 0.6585020073838521, 0.44839171842860526, 0.46016170247160254, 0.4740960953041661, 0.4709345487346007, 0.4335646214964055, 0.44935269200255545, 0.46770311121903557, 0.19056131771694082, 0.02015374675601678, 0.01766385305312402, 0.013491316146783205, 0.023905481233240444, 0.022855248993903267, 0.031114368512608473, 0.027454970453486056, 0.027435593015017495, 0.027172636281775273, 0.03505439088581834, 0.03233120417538101}; | |
double[] eth159_1_7 = new double[]{0.17016448546315915, 0.12079961544505, 0.13405675940348458, 0.036411601892918824, 0.09838532527278734, 0.10869454966959739, 0.08953006117432556, 0.1052737866845647, 0.195471743348033, 0.16143906224150467, 0.07090429186401533, 0.08089157235931134, 0.08713357074913494, 0.0835591253443369, 0.4580032674543874, 0.4118916881195556, 0.4737552585866626, 0.4140111763670664, 0.5929302847116354, 0.7712664797916133, 0.713054818976024, 0.863583800095225, 0.8994371609822932, 0.8295884181273315, 0.37915322066639773, 0.06158370231997737, 0.08205150080643812, 0.1301185020217338, 0.09513556513158561, 0, 0.12070910896054486, 0.4159065790322436, 1, 0.731790118162259, 0.850774993444614, 0.6104404857976053, 0.3432107655116547, 0.03552776183611116, 0.15640072948475728, 0.28673908043833385, 0.49853758716856467, 0.03404958617270257, 0.0762141856100671, 0.021294485297431345, 0.11772335029502662, 0.1219532304522296, 0.08734544480946556, 0.028896157744287644, 0.10969635136752222, 0.09158749495288243, 0.19619263850410018, 0.24800961536903943, 0.13057373427055244, 0.13897667362543434, 0.20881374492365165, 0.14273832083182228, 0.15294736090526093, 0.13525369623872252, 0.23268199539426254, 0.2966893940813079, 0.43039986961085086, 0.06810610882218915, 0.08219334552630271, 0.16431587253214588, 0.1459307615840305, 0.00974907274465963, 0.11432264909614179, 0.05211179789615416, 0.1950343299527493, 0.2239922098796779, 0.05440270434538704, 0.8991471165668278}; | |
System.out.println("Should be correlated"); | |
System.out.println(computeMine(po10, global, 0.0) + " " + compute(po10, 0, po10.length, global, 0, global.length, 0.0)); | |
System.out.println(computeMine(eth1_8, global, 0.0) + " " + compute(global, 0, global.length, eth1_8, 0, eth1_8.length, 0.0)); | |
System.out.println("Should not be correlated"); | |
System.out.println(computeMine(global, eth161_1_2, 0.0) + " " + compute(eth161_1_2, 0, eth161_1_2.length, global, 0, global.length, 0.0)); | |
System.out.println(computeMine(global, eth159_1_7, 0.0) + " " + compute(global, 0, global.length, eth159_1_7, 0, eth159_1_7.length, 0.0)); | |
assertEquals(computeMine(global, eth161_1_2, 0.0), computeMine(eth161_1_2, global, 0.0), 1e-5); | |
assertEquals(compute(global, 0, global.length, eth161_1_2, 0, eth161_1_2.length, 0.0), compute(eth161_1_2,0, eth161_1_2.length, global,0, global.length, 0.0), 1e-5); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment