Created
November 27, 2020 19:37
-
-
Save thomaspinder/0c9156ee6adc1c90d1409586886216f8 to your computer and use it in GitHub Desktop.
Sparse spectrum Gaussian process implemented in GPFlow 2.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gpflow | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
from gpflow.models import GPModel, InternalDataTrainingLossMixin, GPR | |
from gpflow.kernels import SquaredExponential, Kernel | |
from gpflow.likelihoods import Gaussian | |
from gpflow.mean_functions import Zero | |
from gpflow.base import Parameter | |
from time import time | |
from gpflow.utilities import print_summary | |
from gpflow.models.model import InputData, MeanAndVariance | |
from typing import Tuple | |
import numpy as np | |
# Seed both random number generators with the same value so that runs are
# repeatable.
_SEED = 123
tf.random.set_seed(_SEED)
np.random.seed(_SEED)
class SSGP(GPModel, InternalDataTrainingLossMixin):
    """
    Sparse spectrum Gaussian process regression model.

    The model uses a squared-exponential kernel, a Gaussian likelihood and a
    zero mean function. Inputs are mapped to ``2 * M`` trigonometric features
    built from ``M`` frequencies ``omega``, which are drawn from a standard
    normal at construction time and divided by the kernel lengthscales
    whenever they are used.
    """

    def __init__(self, data: Tuple, M: int):
        """
        For the brevity of only implementing a single spectral density, we'll assume an RBF kernel.

        :param data: A tuple ``(X, y)`` of 2-dimensional arrays; ``X.shape[1]`` is taken as the
            input dimension and ``y`` is matrix-multiplied against the features, so it is
            expected to be a column of shape [N, 1]
        :param M: The number of MC samples (spectral frequencies) to make
        :raises ValueError: if ``M`` is smaller than 1
        """
        if M < 1:
            raise ValueError(f"M must be a positive number of spectral samples, got {M}")
        super().__init__(kernel=SquaredExponential(),
                         likelihood=Gaussian(),
                         mean_function=Zero(),
                         num_latent_gps=1)
        self.data = data
        self.input_dim = self.data[0].shape[1]
        self.n_features = M
        # Frequencies are stored unscaled; see _scale_omega for the lengthscale division.
        self.omega = Parameter(np.random.randn(self.n_features, self.input_dim))

    def maximum_log_likelihood_objective(self) -> tf.Tensor:
        """Return the objective to maximise, which is the log marginal likelihood."""
        return self.log_marginal_likelihood()

    def _scale_omega(self) -> tf.Tensor:
        """Return the frequencies divided by the kernel lengthscales."""
        return self.omega / self.kernel.lengthscales

    def compute_phi(self, X, w) -> tf.Tensor:
        """
        Compute φ(x)=[cos(ω1x), cos(ω2x),...,cos(ωMx), sin(ω1x),...,sin(ωMx)] for every row of X.

        :param X: Inputs, one point per row
        :param w: Scaled frequencies, one frequency per row
        :return: Feature matrix with the cosine block followed by the sine block along axis 1
        """
        phi_inner = tf.matmul(X, w, transpose_b=True)
        return tf.concat((tf.cos(phi_inner), tf.sin(phi_inner)), axis=1)

    def _posterior_factors(self) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Build the quantities shared by the marginal likelihood and the predictor.

        With Φ the training feature matrix, this forms
        A = (σ_f² / M) Φ'Φ + σ_n² I and its lower Cholesky factor R (A = R R').

        :return: ``(w, R, R⁻¹Φ'y)`` where ``w`` are the scaled frequencies
        """
        X, y = self.data
        w = self._scale_omega()
        phi = self.compute_phi(X, w)
        A = (self.kernel.variance / self.n_features) * tf.matmul(
            phi, phi, transpose_a=True
        ) + self.likelihood.variance * tf.eye(2 * self.n_features, dtype=phi.dtype)
        Rt = tf.linalg.cholesky(A)
        # Only R⁻¹(Φ'y) is ever needed, so solve against the [2M, 1] vector
        # rather than against the full [2M, N] matrix Φ'.
        RtiPhity = tf.linalg.triangular_solve(Rt, tf.matmul(phi, y, transpose_a=True))
        return w, Rt, RtiPhity

    def log_marginal_likelihood(self) -> tf.Tensor:
        """
        Construct the marginal log-likelihood

        :return: Scalar tensor holding the log marginal likelihood of the training data
        """
        X, y = self.data
        n = X.shape[0]
        lik_var = self.likelihood.variance
        kern_var = self.kernel.variance
        _, Rt, RtiPhity = self._posterior_factors()
        # Data-fit term: (y'y - (σ_f²/M) |R⁻¹Φ'y|²) / (2σ_n²)
        data_fit = (
            tf.reduce_sum(tf.square(y))
            - tf.reduce_sum(tf.square(RtiPhity)) * kern_var / self.n_features
        ) * 0.5 / lik_var
        # Complexity and normalisation terms. log(2π) is evaluated with NumPy
        # so that it is computed in double precision.
        complexity = (
            tf.reduce_sum(tf.math.log(tf.linalg.diag_part(Rt)))
            + (n * 0.5 - self.n_features) * tf.math.log(lik_var)
            + n * 0.5 * np.log(2 * np.pi)
        )
        return -(data_fit + complexity)

    def predict_f(self,
                  Xnew: InputData,
                  full_cov: bool = False,
                  full_output_cov: bool = False) -> MeanAndVariance:
        """
        Compute the predictive mean and variance of the latent function at ``Xnew``.

        :param Xnew: Test inputs, one point per row
        :param full_cov: If True return the full covariance between the test points,
            otherwise only the marginal variances
        :param full_output_cov: Accepted for interface compatibility; not used because
            the model has a single latent GP
        :return: ``(mean, variance)``; the mean has one row per test point, the variance
            has shape [N*, 1], or [1, N*, N*] when ``full_cov`` is True
        """
        # Keep the hyperparameters symbolic (no .numpy()) so the method also
        # works when traced inside a tf.function.
        scale = self.kernel.variance / self.n_features
        var_scale = self.likelihood.variance * scale
        w, Rt, RtiPhity = self._posterior_factors()
        # Compute sinusoidal coefficients
        alpha = scale * tf.linalg.triangular_solve(tf.transpose(Rt), RtiPhity, lower=False)
        # Compute phi given new inputs
        phistar = self.compute_phi(Xnew, w)
        # Predictive mean
        mu = tf.matmul(phistar, alpha)
        # Predictive variance
        RtiPhiStarT = tf.linalg.triangular_solve(Rt, tf.transpose(phistar))
        if full_cov:
            # Return full covariance matrix
            cov = var_scale * tf.matmul(RtiPhiStarT, RtiPhiStarT, transpose_a=True)
            # The jitter must share the covariance dtype: tf.eye defaults to
            # float32 and TensorFlow does not promote dtypes on addition.
            jitter = 1e-6 * tf.eye(tf.shape(cov)[0], dtype=cov.dtype)
            # Latent-GP axis goes first, following GPflow's [P, N, N] layout.
            sigma = tf.expand_dims(cov + jitter, axis=0)
        else:
            # Just return the predictive variances
            sigma = var_scale * tf.reduce_sum(tf.square(RtiPhiStarT), axis=0)
            sigma = tf.expand_dims(sigma, axis=1)
        return mu, sigma
if __name__ == '__main__':
    # Demo: fit the SSGP to noisy samples of a sum of sinusoids, plot the
    # optimisation trace and then the predictions with a 95% band.
    def func(x):
        """Noise-free target: a sum of three sinusoids."""
        return np.sin(x * 3 * 3.14) + 0.3 * np.cos(x * 9 * 3.14) + 0.5 * np.sin(x * 7 * 3.14)

    n = 1000
    X = np.sort(np.random.uniform(low=-1, high=1, size=n).reshape(-1, 1), axis=0)
    y = func(X) + 0.2 * np.random.randn(n, 1)

    m = SSGP((X, y), M=50)
    print_summary(m)
    print("Original MLL")
    print(m.log_marginal_likelihood())

    start = time()
    opt = tf.optimizers.Adam(learning_rate=0.001)
    logfs = []
    for _ in range(2000):
        opt.minimize(m.training_loss, m.trainable_variables)
        # training_loss is negated before logging, so the trace plotted below
        # shows the quantity being maximised.
        logfs.append(-m.training_loss().numpy())
    elapsed = time() - start
    print(f"SSGP: {elapsed}")

    plt.plot(logfs)
    plt.show()
    print("Post-optimisation MLL")
    print(m.log_marginal_likelihood())

    # gpr = GPR((X, y), kernel=SquaredExponential())
    # print_summary(gpr)
    # print("Original MLL")
    # print(gpr.log_marginal_likelihood())
    #
    # start = time()
    # opt = tf.optimizers.Adam(learning_rate=0.01)
    # for i in range(500):
    # opt.minimize(gpr.training_loss, gpr.trainable_variables)
    # end = time() - start
    # print(f"GPR: {end}")
    #
    # print("Post-optimisation MLL")
    # print(gpr.log_marginal_likelihood())
    #
    #

    Xte = np.linspace(-1.25, 1.25, 500).reshape(-1, 1)
    mu, sigma = m.predict_f(Xte)
    plt.plot(X, y, 'o', alpha=0.5)
    plt.plot(Xte, mu.numpy())
    mus = mu.numpy().squeeze()
    # predict_f returns variances; the +/-1.96 band needs standard deviations.
    stds = np.sqrt(sigma.numpy().squeeze())
    plt.fill_between(Xte.squeeze(), mus - 1.96 * stds, mus + 1.96 * stds, alpha=0.5)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment