Skip to content

Instantly share code, notes, and snippets.

@Kaixhin
Last active August 26, 2016 23:14
Show Gist options
  • Save Kaixhin/4035df9462491770ec0ecb4b6f4fe331 to your computer and use it in GitHub Desktop.
Save Kaixhin/4035df9462491770ec0ecb4b6f4fe331 to your computer and use it in GitHub Desktop.
Gaussian Processes for Dummies
--[[
-- Gaussian Processes for Dummies
-- https://katbailey.github.io/post/gaussian-processes-for-dummies/
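-- Note 1: The Cholesky decomposition requires positive-definite matrices, but kernel matrices can be numerically close to singular; hence the addition of a small value (jitter) to the diagonal, which keeps the pivots on the diagonal of the Cholesky factor away from zero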
-- Note 2: This can also be thought of as adding a little noise to the observations
--]]
local gnuplot = require 'gnuplot'
-- Test inputs: n evenly spaced points on [-5, 5]
local n = 50
local testGrid = torch.linspace(-5, 5, n)
-- Stored as an n x 1 column so that each row is one input point for the kernel
local Xtest = testGrid:view(n, 1)
-- Define the kernel function (squared exponential kernel)
--- Squared exponential (RBF) kernel between two sets of points.
-- Each row of `a` and `b` is one input point; both must have the same number
-- of columns (feature dimensions).
-- @param a  na x d tensor of inputs
-- @param b  nb x d tensor of inputs
-- @param param  squared length-scale: K[i][j] = exp(-0.5 * ||a_i - b_j||^2 / param)
-- @return na x nb kernel matrix
local kernel = function(a, b, param)
  local na, nb = a:size(1), b:size(1)
  assert(a:size(2) == b:size(2), 'kernel: inputs must have the same number of columns')
  -- Torch7 has no broadcasting: expanding an na x 1 column against a 1 x nb row
  -- forms all pairwise differences for one feature dimension; sum over dimensions
  local sqdist
  for d = 1, a:size(2) do
    local diff = a[{{}, {d}}]:expand(na, nb) - b[{{}, {d}}]:t():expand(na, nb)
    local sq = torch.pow(diff, 2)
    if sqdist then
      sqdist:add(sq)
    else
      sqdist = sq
    end
  end
  return torch.exp(-0.5 * (1 / param) * sqdist)
end
-- Kernel hyperparameter: squared length-scale (see the kernel's exp(-0.5 * d^2 / param))
local param = 0.1
local K_ss = kernel(Xtest, Xtest, param)
-- Lower Cholesky factor ("square root") of the prior covariance: K_ss = L * L^T.
-- The tiny diagonal jitter keeps the factorisation numerically positive-definite
local L = torch.potrf(K_ss + 1e-15 * torch.eye(n), 'L')
-- If z ~ N(0, I) then L * z ~ N(0, L * L^T) = N(0, K_ss); multiplying the whole
-- n x 3 matrix of standard normals transforms all three samples at once
local f_prior = L * torch.randn(n, 3)
-- Plot the 3 sampled functions
local xs = Xtest:squeeze()
gnuplot.figure()
gnuplot.plot({'', xs, f_prior[{{}, {1}}], '-'}, {'', xs, f_prior[{{}, {2}}], '-'}, {'', xs, f_prior[{{}, {3}}], '-'})
gnuplot.axis({-5, 5, -3, 3})
gnuplot.title('Three samples from the GP prior')
-- Noiseless training data
local Xtrain = torch.Tensor({-4, -3, -2, -1, 1}):view(5, 1)
local ytrain = torch.sin(Xtrain)
-- Lower Cholesky factor of the training covariance: K = L * L^T
local K = kernel(Xtrain, Xtrain, param)
L = torch.potrf(K + 1e-15 * torch.eye(Xtrain:size(1)), 'L')
-- Posterior mean at the test points: mu = K_s^T K^-1 y = (L^-1 K_s)^T (L^-1 y)
local K_s = kernel(Xtrain, Xtest, param)
local Lk = torch.gesv(K_s, L) -- solves L * Lk = K_s (only the first return value is kept)
local mu = Lk:t() * torch.gesv(ytrain, L)
-- Posterior variance: diag(K_ss) minus the column sums of Lk^2. Rounding can push
-- values at (or very near) training inputs slightly below zero, which would make
-- sqrt return NaN, so clamp at zero first
local s2 = torch.diag(K_ss) - torch.sum(torch.pow(Lk, 2), 1)
local stdv = torch.sqrt(s2:clamp(0, math.huge))
-- Draw samples from the posterior. The LOWER factor must be requested explicitly:
-- torch.potrf returns the upper factor U (A = U^T * U) by default, and U * z would
-- have covariance U * U^T rather than the posterior covariance A
L = torch.potrf(K_ss + 1e-15 * torch.eye(n) - Lk:t() * Lk, 'L')
-- Torch has no broadcasting, so replicate the mean across the 3 sample columns
local f_post = torch.repeatTensor(mu, 1, 3) + L * torch.randn(n, 3)
Xtrain = Xtrain:squeeze()
Xtest = Xtest:squeeze()
local muLower, muUpper = mu - 2 * stdv, mu + 2 * stdv
local yy = torch.cat({Xtest, muLower, muUpper}, 2) -- 2 standard deviation bars
gnuplot.figure()
gnuplot.plot({yy, 'with filledcurves fill transparent solid 0.5'}, {'GP posterior mean', Xtest, mu, '-'}, {'', Xtrain, ytrain, '+'}, {'', Xtest, f_post[{{}, {1}}], '-'}, {'', Xtest, f_post[{{}, {2}}], '-'}, {'', Xtest, f_post[{{}, {3}}], '-'})
gnuplot.axis({-5, 5, -3, 3})
gnuplot.title('Three samples from the GP posterior')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment