Skip to content

Instantly share code, notes, and snippets.

@gpantalos
Created April 13, 2021 08:46
Show Gist options
  • Save gpantalos/f046339124378359f5fff7e59e4ac94a to your computer and use it in GitHub Desktop.
Save gpantalos/f046339124378359f5fff7e59e4ac94a to your computer and use it in GitHub Desktop.
Gaussian Process in NumPy
import matplotlib.pyplot as plt
import numpy as np
# Module-level shorthand for NumPy's dense matrix inverse.
inv = np.linalg.inv
class GaussianProcess:
    """Exact Gaussian-process regression with simple 1-D plotting helpers.

    Parameters
    ----------
    kernel : callable
        ``kernel(x1, x2)`` returning the covariance matrix between the rows
        of two column arrays (as produced by ``_handle``); it must return an
        ``(len(x1), len(x2))`` array.
    observed_index_points, observed_values : array_like
        Training inputs and targets. 1-D input is reshaped to an (n, 1)
        column.
    mean_function : callable, optional
        Prior mean ``m(x)``; defaults to zeros shaped like its input. It is
        called with (n, 1) column arrays and should return the same shape.
    stddev : float, optional
        Multiplicative scale applied to every kernel evaluation.
    noise_var : float, optional
        Observation-noise variance added to the training-covariance
        diagonal. A 1e-6 jitter is always added on top for numerical
        stability.
    num_index_points : int, optional
        Size of the evenly spaced grid (spanning the observed inputs) that
        ``plot`` evaluates on.
    """

    def __init__(self, kernel, observed_index_points, observed_values,
                 mean_function=None, stddev=1.0, noise_var=0.0,
                 num_index_points=250):
        self.observed_index_points = self._handle(observed_index_points)
        self.observed_values = self._handle(observed_values)
        # Kernel parameters.
        self.stddev = stddev
        if mean_function is None:
            self.mean_function = lambda x: np.zeros_like(x)
        else:
            self.mean_function = mean_function
        self._kernel = kernel
        # Observation noise plus a small jitter that keeps the training
        # covariance numerically positive definite.
        self.noise_var = noise_var + 1e-6
        # Plotting grid spanning the observed inputs.
        lo = np.min(self.observed_index_points)
        hi = np.max(self.observed_index_points)
        self.index_points = np.linspace(lo, hi, num_index_points)

    @staticmethod
    def _handle(x):
        """Return ``x`` as a NumPy array, turning 1-D input into an (n, 1) column."""
        x = np.array(x)
        if x.ndim == 1:
            x = x[..., None]
        return x

    def _cov(self, x1, x2):
        """Kernel matrix between ``x1`` and ``x2`` scaled by ``self.stddev``."""
        # NOTE(review): the kernel is scaled by ``stddev`` itself, not by
        # ``stddev ** 2``. If ``stddev`` is meant as the signal standard
        # deviation this should be squared. Kept as-is to preserve results.
        return self._kernel(x1, x2) * self.stddev

    def prior(self, index_points):
        """Return ``(mean, cov)`` of the GP prior evaluated at ``index_points``."""
        index_points = self._handle(index_points)
        return self.mean_function(index_points), self._cov(index_points, index_points)

    def posterior(self, index_points):
        """Return ``(mean, cov)`` of the GP posterior at ``index_points``.

        Standard GP conditioning on the observed data::

            A    = K(X, X) + noise_var * I
            mean = m(x*) + K(X, x*)^T A^-1 (y - m(X))
            cov  = K(x*, x*) - K(X, x*)^T A^-1 K(X, x*)

        The linear systems are solved directly instead of forming ``A^-1``.
        This is cheaper and numerically more stable.
        """
        index_points = self._handle(index_points)
        x_obs = self.observed_index_points
        residuals = self.observed_values - self.mean_function(x_obs)

        k_test = self._cov(index_points, index_points)
        k_cross = self._cov(x_obs, index_points)
        k_train = self._cov(x_obs, x_obs)
        # Out-of-place add: never mutate the array the kernel handed back.
        k_train = k_train + self.noise_var * np.eye(len(k_train))

        alpha = np.linalg.solve(k_train, residuals)
        v = np.linalg.solve(k_train, k_cross)
        mean = k_cross.T @ alpha + self.mean_function(index_points)
        cov = k_test - k_cross.T @ v
        return mean, cov

    def _predictive(self, distribution):
        """Return ``(mean, cov)`` of ``distribution`` on the plotting grid.

        Raises
        ------
        ValueError
            If ``distribution`` is not ``'prior'`` or ``'posterior'``.
        """
        if distribution == 'prior':
            return self.prior(self.index_points)
        if distribution == 'posterior':
            return self.posterior(self.index_points)
        raise ValueError(
            f"distribution must be 'prior' or 'posterior', got {distribution!r}")

    def plot(self, distribution, plot_samples, num_samples=5):
        """Plot the mean +/- one standard deviation of the prior or posterior.

        Parameters
        ----------
        distribution : {'prior', 'posterior'}
            Which distribution to draw.
        plot_samples : bool
            Also draw ``num_samples`` random functions from the distribution.
        num_samples : int, optional
            Number of sample paths drawn when ``plot_samples`` is true.

        Raises
        ------
        ValueError
            If ``distribution`` is not ``'prior'`` or ``'posterior'``.
        """
        mean, cov = self._predictive(distribution)
        mean = mean[..., 0]
        # Round-off in the posterior subtraction can leave slightly negative
        # diagonal entries. Clip them so the uncertainty band is not NaN.
        stddev = np.sqrt(np.clip(np.diag(cov), 0.0, None))
        if plot_samples:
            for _ in range(num_samples):
                sample = np.random.multivariate_normal(mean, cov)
                plt.plot(self.index_points, sample, c='r', lw=0.5, alpha=0.5)
        plt.plot(self.index_points, mean)
        plt.fill_between(self.index_points, mean - stddev, mean + stddev, alpha=0.2)
        plt.scatter(self.observed_index_points, self.observed_values, marker='x', s=5)
        plt.show()
def rbf(x1, x2, l=0.5):
    """Squared-exponential (RBF) kernel with length-scale ``l``.

    Evaluates ``exp(-|a - b|**2 / (2 * l**2))`` for every pairing of the
    entries of ``x1`` with those of ``x2.T``, relying on NumPy broadcasting
    (column arrays of shape (n, 1) and (m, 1) give an (n, m) matrix).
    c.f. Gaussian Processes for Machine Learning p. 84.
    """
    sq_dist = np.square(np.abs(x1 - x2.T))
    two_l_sq = 2 * l ** 2
    return np.exp(-(sq_dist / two_l_sq))
def f(x):
    """Toy target ``sin(10 x) * exp(-x**2)``: an oscillation under a Gaussian envelope."""
    envelope = np.exp(-np.square(x))
    return np.sin(10 * x) * envelope
if __name__ == '__main__':
    # Demo: condition the GP on 50 noise-free evaluations of ``f`` at inputs
    # drawn uniformly from [-1, 1], then plot prior and posterior with
    # sample paths. ``gp`` is kept as a module-level name on purpose.
    observed_index_points_ = np.random.uniform(-1, 1, 50)
    observed_values_ = f(observed_index_points_)
    gp = GaussianProcess(rbf, observed_index_points_, observed_values_)
    for which in ("prior", "posterior"):
        gp.plot(which, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment