Last active
September 14, 2021 10:02
-
-
Save pastewka/bda88cce8b0d193a34c2336422829a44 to your computer and use it in GitHub Desktop.
Simple Gaussian process regression example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Simple Gaussian process regression example""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def kernel(x1, x2, length_scale=1, signal_variance=1):
    """Squared-exponential covariance between ``x1`` and ``x2``.

    Evaluates ``signal_variance * exp(-(x1 - x2)**2 / (2 * length_scale**2))``
    element-wise; the two arguments are combined with ordinary NumPy
    broadcasting, so a column vector and a row vector yield a full matrix.

    Parameters
    ----------
    x1, x2
        Input locations (scalars or broadcast-compatible arrays).
    length_scale
        Distance over which the covariance decays.
    signal_variance
        Value of the covariance at zero distance.

    Returns
    -------
    Covariance value(s) with the broadcast shape of ``x1 - x2``.
    """
    squared_distance = (x1 - x2) ** 2
    return signal_variance * np.exp(-squared_distance / (2 * length_scale ** 2))
class GaussianProcessRegression:
    """Gaussian process regression for scalar inputs.

    The model is parameterised by a covariance function ``kernel(a, b)``.
    It is always called with a column vector ``a`` (shape ``(n, 1)``) and a
    row vector ``b`` (shape ``(1, m)``) and must broadcast them to the
    ``(n, m)`` covariance matrix.
    """

    def __init__(self, kernel):
        """Store the covariance function; the model starts out untrained."""
        self._kernel = kernel
        # Populated by train(); None marks the untrained state.
        self._obs_cov = None
        self._coeff = None
        self._x = None
        self._y = None

    def train(self, x, y, noise_variance=0.01):
        """Fit the model to observations ``y`` taken at locations ``x``.

        Parameters
        ----------
        x : array-like
            Training input locations.
        y : array-like
            Observed values, one per training input.
        noise_variance : float
            Variance of the observation noise, added to the diagonal of the
            observation covariance matrix.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` do not contain the same number of samples.
        """
        # Accept lists/tuples as well as arrays.
        x = np.asarray(x)
        y = np.asarray(y)
        if y.ndim == 0 or y.shape[0] != x.size:
            raise ValueError(
                "x and y must contain the same number of samples")

        # Covariance between observations, with the noise term added on the
        # diagonal.  The sum is formed out-of-place so a kernel returning an
        # integer array cannot trigger an in-place casting error.
        self._obs_cov = (
            self._kernel(x.reshape(-1, 1), x.reshape(1, -1))
            + noise_variance * np.identity(x.size))
        # Kernel coefficients: solution of obs_cov @ coeff = y
        self._coeff = np.linalg.solve(self._obs_cov, y)
        # Store training set
        self._x = x
        self._y = y

    def predict(self, x):
        """Evaluate the posterior at the test locations ``x``.

        Parameters
        ----------
        x : array-like
            Test input locations.

        Returns
        -------
        pred_mean : numpy.ndarray
            Predictive mean at each test location.
        pred_cov : numpy.ndarray
            Predictive covariance matrix between all test locations.

        Raises
        ------
        RuntimeError
            If called before :meth:`train`.
        """
        if self._coeff is None:
            raise RuntimeError("train() must be called before predict()")

        x = np.asarray(x)
        test_row = x.reshape(1, -1)
        # Covariance between test outputs
        test_cov = self._kernel(x.reshape(-1, 1), test_row)
        # Covariance between observation and test outputs
        obs_test_cov = self._kernel(self._x.reshape(-1, 1), test_row)
        # Predictive mean
        pred_mean = self._coeff.dot(obs_test_cov)
        # Predictive covariance: prior covariance minus the part explained
        # by the observations.
        pred_cov = test_cov - obs_test_cov.T.dot(
            np.linalg.solve(self._obs_cov, obs_test_cov))
        return pred_mean, pred_cov
# Demo: fit the Gaussian process to random observations on an integer grid
# and plot the predictive mean with a one-standard-deviation band.
n_input = 10
n_pred = 100

x = np.arange(n_input)
y = np.random.random(n_input)
# The prediction grid extends 5 units past the data on both sides so the
# plot also shows how the prediction behaves away from the observations.
x_pred = np.linspace(x.min() - 5, x.max() + 5, n_pred)

regression = GaussianProcessRegression(kernel)
regression.train(x, y)
y_pred, cov_pred = regression.predict(x_pred)

# Floating-point round-off can push a near-zero predictive variance slightly
# below zero; clamp at zero so the square root never produces NaN.  The
# standard deviation is computed once and reused for both band edges.
std_pred = np.sqrt(np.maximum(cov_pred.diagonal(), 0))

plt.fill_between(x_pred, y_pred + std_pred, y_pred - std_pred)
plt.plot(x, y, 'kx')
plt.plot(x_pred, y_pred, 'k-')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment