Skip to content

Instantly share code, notes, and snippets.

@paultsw
Created August 21, 2017 16:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paultsw/666711643442d8283252011a93b88241 to your computer and use it in GitHub Desktop.
Gaussian process regression for 1d signals with sklearn
"""
Fit a Gaussian process to a signal using SKLearn.
"""
import numpy as np
from sklearn import gaussian_process
from sklearn.gaussian_process import kernels as K
import matplotlib.pyplot as plt
from scipy.signal import resample
import argparse
def clip_signal(sig, tol):
    """
    Clamp each sample of `sig` into the band [mean - tol, mean + tol].

    Args:
        sig: array-like of samples; the band is centred on its mean.
        tol: half-width of the allowed band.

    Returns:
        An array in which samples outside the band are saturated at the
        nearest band edge (values are clamped, not dropped).
    """
    mu = np.mean(sig)
    lower, upper = mu - tol, mu + tol
    return np.clip(sig, lower, upper)
def build_kernel():
    """
    Construct a custom kernel.

    The kernel is the sum of two terms:
    - an RBF (squared-exponential) kernel with initial length_scale 10;
    - a WhiteKernel with initial noise_level 20.

    These are starting values only. GaussianProcessRegressor re-tunes
    kernel hyperparameters during fit() whenever an optimizer is configured.

    [N.B.: if you want to experiment with different kernels, this is the only function you should change.]
    """
    smooth_term = K.RBF(length_scale=10.)
    white_term = K.WhiteKernel(noise_level=20.)
    return smooth_term + white_term
def _parse_bool(text):
    """
    Parse a command-line boolean such as "True", "false", "yes" or "0".

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``--clip False`` used to *enable* clipping.
    This converter interprets the text instead.

    Args:
        text: the raw argument string.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if `text` is not a recognised boolean
            spelling (argparse reports this as a usage error).
    """
    normalised = text.strip().lower()
    if normalised in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalised in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false, yes/no, 1/0), got {!r}".format(text))


def _build_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip", dest='clip', type=_parse_bool, default=False,
                        help="If True, clip any samples that fall outside of 3 STDVs. [Default: False]")
    parser.add_argument("--subsample", dest='subsample', type=int, default=5000,
                        help="Number of points to subsample from read; if 0, use whole read. [Default: 5000]")
    parser.add_argument("signal_file")
    return parser


def main(argv=None):
    """
    Fit a Gaussian process to a stored signal and plot the fitted mean.

    Steps:
    1. Parse the CLI arguments.
    2. Load the array with ``np.load``.
    3. Optionally clip and/or resample it.
    4. Fit a ``GaussianProcessRegressor`` using the kernel from ``build_kernel()``.
    5. Plot the samples together with the predicted mean curve.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which preserves the original CLI
            behaviour.
    """
    ### parse CLI args:
    parser = _build_parser()
    args = parser.parse_args(argv)

    ### load signal:
    # NOTE(review): the reshapes below assume a 1-D array; confirm for other shapes.
    signal = np.load(args.signal_file)
    # Fail early with a clear message; np.amax/np.amin would otherwise raise
    # an opaque ValueError on an empty array further down.
    if np.size(signal) == 0:
        parser.error("signal file {!r} contains no samples".format(args.signal_file))

    ### optionally clip the open-pore signal at the start and end:
    # Samples further than 3 standard deviations from the mean are saturated
    # at mean +/- 3*std. They are clamped, not removed.
    if args.clip:
        signal = clip_signal(signal, 3 * np.std(signal))

    ### optionally subsample:
    # If S := subsample > 0, Fourier-resample the signal to exactly S points.
    # This upsamples when S exceeds the read length. Otherwise the whole
    # read is used.
    if args.subsample > 0:
        signal = resample(signal, args.subsample)

    ### compute signal statistics:
    sig_max = np.amax(signal)
    sig_min = np.amin(signal)
    n_points = len(signal)
    # n_points evenly spaced positions on [0, 2*n_points], endpoints included.
    x_ticks = np.linspace(start=0, stop=(2 * n_points), num=n_points)

    ### perform kriging on the signal data:
    gpr = gaussian_process.GaussianProcessRegressor(
        kernel=build_kernel(),
        optimizer='fmin_l_bfgs_b',
        n_restarts_optimizer=5,  # extra restarts of the hyperparameter search
        normalize_y=True)
    # sklearn expects a 2-D (n_samples, n_features) input matrix.
    _X = x_ticks.reshape(-1, 1)
    _y = signal.reshape(-1, 1)
    print("Fitting to dataset... Be patient, this may take a while.")
    gpr.fit(_X, _y)
    print("...Done. Generating predicted mean curve and plotting...")
    predictions = gpr.predict(_X).reshape(-1)

    ### plot GPR predictions:
    plt.plot(x_ticks, signal, 'o')
    # (un-)comment the next two lines to hide/show 3sigma bounding lines:
    #plt.plot(x_ticks, np.ones_like(x_ticks) * (np.mean(signal)+3*np.std(signal)), '-')
    #plt.plot(x_ticks, np.ones_like(x_ticks) * (np.mean(signal)-3*np.std(signal)), '-')
    plt.plot(x_ticks, predictions, '-')
    plt.xlim([-1, n_points * 2 + 1])
    plt.ylim([sig_min - 1, sig_max + 1])
    plt.show()
# Script entry point: call main() only when executed from the CLI, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment