Created
August 11, 2016 14:54
-
-
Save jampekka/3fd5fa2a2d0388a9dd42291b783cb631 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import scipy.interpolate
import scipy.signal
import scipy.stats
def match_sampled_brutal(ts1, haystack, ts2, needle, minlen=100):
    """Find the time offset that best aligns ``needle`` with ``haystack``.

    Brute-force search: both time axes are rebased to start at zero, the
    needle is linearly interpolated onto the (shifted) haystack timestamps
    for every candidate lag, and the lag with the highest Pearson
    correlation over the overlapping samples wins.

    Parameters
    ----------
    ts1 : array-like
        Timestamps of ``haystack``.
    haystack : array-like
        Signal to search in, one value per entry of ``ts1``.
    ts2 : array-like
        Timestamps of ``needle``.
    needle : array-like
        Signal to locate, one value per entry of ``ts2``.
    minlen : int, optional
        Minimum number of overlapping samples a lag needs in order to be
        considered at all.

    Returns
    -------
    float
        The lag to add to ``ts2`` so that the needle lines up with the
        haystack on the ``ts1`` time axis.

    Raises
    ------
    ValueError
        If the sampling step of ``ts1`` cannot be determined, or if no
        candidate lag yields a usable correlation.
    """
    ts1 = np.asarray(ts1, dtype=float)
    ts2 = np.asarray(ts2, dtype=float)
    haystack = np.asarray(haystack)
    needle = np.asarray(needle)

    if ts1.size < 2:
        raise ValueError("ts1 needs at least two timestamps")
    # The lag grid uses the typical sampling interval of the haystack.
    dt = np.median(np.diff(ts1))
    if not np.isfinite(dt) or dt == 0:
        raise ValueError("cannot derive a sampling step from ts1")

    # Work on zero-based time axes; the bases are added back at the end.
    ts1_base = ts1[0]
    ts1 = ts1 - ts1_base
    ts2_base = ts2[0]
    ts2 = ts2 - ts2_base

    # Outside the needle's own time span the interpolator yields NaN, which
    # is what marks the non-overlapping samples below.
    needle_at = scipy.interpolate.interp1d(ts2, needle, bounds_error=False)

    # Pearson correlation needs at least two points regardless of minlen.
    required = max(minlen, 2)

    lags = np.arange(-ts1[-1], ts1[-1], dt)
    corrs = np.full(len(lags), np.nan)
    for i, lag in enumerate(lags):
        shifted = needle_at(ts1 - lag)
        valid = ~np.isnan(shifted)
        # Check the overlap *before* correlating: lags with little or no
        # overlap must be skipped, not handed to pearsonr (which rejects
        # inputs shorter than two samples).
        if np.count_nonzero(valid) < required:
            continue
        corrs[i] = scipy.stats.pearsonr(shifted[valid], haystack[valid])[0]

    if np.all(np.isnan(corrs)):
        raise ValueError(
            "no candidate lag produced a valid correlation "
            "(overlap shorter than minlen?)"
        )
    return lags[np.nanargmax(corrs)] + (ts1_base - ts2_base)
def demo():
    """Run a synthetic example and plot the recovered alignment.

    Builds a chirp signal, cuts out a noisy, down-sampled chunk whose
    timestamps are shifted by a known lag, asks ``match_sampled_brutal``
    to recover that lag, prints estimate vs. truth and plots both signals
    on the haystack's time axis.
    """
    # Imported lazily so that plotting is only required for the demo.
    import matplotlib.pyplot as plt

    ts = np.arange(0, 100, 0.1)
    n = len(ts)
    signal = np.sin((ts / ts[-1] * np.pi * 6) ** 2)
    # Arbitrary non-zero start time for the haystack clock.
    ts += 1334

    # Every second sample from the 60%..80% stretch of the signal.
    chunk = slice(int(n * 0.6), int(n * 0.8), 2)
    ts2 = ts[chunk].copy()
    true_lag = 34234
    ts2 -= true_lag
    signal2 = signal[chunk].copy() + np.random.randn(len(ts2)) * 0.3

    best_lag = match_sampled_brutal(ts, signal, ts2, signal2)
    # %-formatting inside a single-argument print() gives the same output
    # under both Python 2 and Python 3.
    print("Estimate: %s Thruth: %s" % (best_lag, true_lag))

    plt.plot(ts, signal)
    plt.plot(ts2 + best_lag, signal2)
    plt.show()
# Run the demonstration only when executed as a script, not on import.
if __name__ == '__main__':
    demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment