Skip to content

Instantly share code, notes, and snippets.

@maximerischard
Created August 24, 2020 20:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maximerischard/1f89081be235f374217bd63dc7a496a8 to your computer and use it in GitHub Desktop.
Save maximerischard/1f89081be235f374217bd63dc7a496a8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m environment at `~/Documents/Harvard/GaussianProcesses/Project.toml`\n",
"\u001b[32m\u001b[1mPrecompiling\u001b[22m\u001b[39m project...\n",
"┌ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]\n",
"└ @ Base loading.jl:1260\n",
"┌ Info: Precompiling ScikitLearn [3646fa90-6ef7-5e7e-9f22-8aca16db6324]\n",
"└ @ Base loading.jl:1260\n"
]
}
],
"source": [
"# executing this cell will install all required julia packages\n",
"import Pkg\n",
"Pkg.activate(\".\")\n",
"Pkg.instantiate()\n",
"Pkg.precompile()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import PyPlot; plt=PyPlot\n",
"using LaTeXStrings\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"using Distributions\n",
"import Calculus\n",
"using LinearAlgebra\n",
"using PDMats\n",
"using Profile"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"using Revise\n",
"using GaussianProcesses\n",
"using GaussianProcesses: get_params, update_mll!, set_params!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Simulated 1-D regression data: n noisy observations of f_star at uniform inputs on [-2, 2].\n",
"using Random\n",
"Random.seed!(1)\n",
"# n observations; p is assigned but not read by any other cell\n",
"n,p = 100,1\n",
"f_star(x::Real) = abs(x-5)*cos(2*x)\n",
"σ_y = 0.8\n",
"X_distr = Uniform(-2,2)\n",
"ϵ_distr = Normal(0,σ_y)\n",
"# inputs are sorted, so index ranges such as the folds used later correspond to intervals of x\n",
"x = sort(rand(X_distr, n))\n",
"Y = f_star.(x) .+ rand(ϵ_distr,n)\n",
"k = SEIso(0.5, 0.8)\n",
"logNoise = log(σ_y)\n",
"# Matrix(x') is a 1×n design matrix: one column per observation\n",
"gp = GP(Matrix(x'), Y, MeanZero(), k, logNoise)\n",
"# fit kernel and noise hyperparameters with GaussianProcesses' own optimize! (MeanZero has no parameters)\n",
"optimize!(gp; domean=false, kern=true, noise=true)\n",
"# kept for comparison with the cross-validation fit at the end of the notebook\n",
"optim_params = get_params(gp; domean=false, kern=true, noise=true)\n",
";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Leave-one-out"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"PyPlot.Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Leave-one-out predictive mean μi and variance σ2i for every training point. The brute-force\n",
"# check below compares them with predict_y on refitted GPs, so σ2i includes observation noise.\n",
"μi,σ2i = GaussianProcesses.predict_LOO(gp)\n",
"# LOO mean with a ±1 standard-deviation band\n",
"plt.plot(x, μi)\n",
"plt.fill_between(x, μi.-sqrt.(σ2i), μi.+sqrt.(σ2i), alpha=0.3)\n",
"# true regression function for reference\n",
"xx = range(-2,stop=2,length=100)\n",
"plt.plot(xx, f_star.(xx))\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-122.58828939891164"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sum over training points of the LOO log predictive density log p(y_i | y_{-i}).\n",
"GaussianProcesses.logp_LOO(gp)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-122.58828939891187"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Brute-force check of predict_LOO / logp_LOO: refit the GP without point i, predict y_i,\n",
"# and accumulate the Gaussian log predictive density of the held-out target.\n",
"# Use gp's own size rather than the global `n`, which other cells may rebind.\n",
"nobs = length(gp.y)\n",
"CV = 0.0\n",
"for i in 1:nobs\n",
"    # i:i keeps a dim×1 input matrix, so the check works for any input dimension\n",
"    xV, yV = gp.x[:, i:i], gp.y[i]\n",
"    T = [j for j in 1:nobs if j != i]\n",
"    xT, yT = gp.x[:, T], gp.y[T]\n",
"    gpT = GPE(xT, yT, gp.mean, gp.kernel, gp.logNoise)\n",
"    pred_i = predict_y(gpT, xV)\n",
"    @assert isapprox(pred_i[1][1], μi[i], atol=1e-6)\n",
"    @assert isapprox(pred_i[2][1], σ2i[i], atol=1e-6)\n",
"    CV += logpdf(Normal(pred_i[1][1], sqrt(pred_i[2][1])), yV)\n",
"end\n",
"CV"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check dlogpdθ_LOO (kernel hyperparameters only) against a finite-difference gradient of logp_LOO.\n",
"function target(θ)\n",
"    θprev = get_params(gp.kernel)\n",
"    try\n",
"        set_params!(gp.kernel, θ)\n",
"        update_mll!(gp)\n",
"        return GaussianProcesses.logp_LOO(gp)\n",
"    finally\n",
"        set_params!(gp.kernel, θprev) # put it back, even if the evaluation throws\n",
"    end\n",
"end\n",
"# differentiate at gp's own kernel parameters so the check does not depend on the global `k`\n",
"# still referring to gp's kernel\n",
"grad_numerical = Calculus.gradient(target, get_params(gp.kernel))\n",
"update_mll!(gp) # target restores the parameters but not gp's cached state; refresh it here\n",
"grad_analytical = GaussianProcesses.dlogpdθ_LOO(gp; noise=false, domean=false, kern=true)\n",
"all(isapprox.(grad_numerical,grad_analytical, atol=1e-6))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Folds"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"PyPlot.Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Four contiguous folds covering all 100 (sorted) training points.\n",
"folds = [1:10, 11:50, 51:80, 81:100]\n",
"# For each fold V: predictive mean vector μVT and covariance matrix ΣVT of the held-out points\n",
"# (presumably conditioned on the other folds' data; confirm in GaussianProcesses).\n",
"μ,Σ=GaussianProcesses.predict_CVfold(gp, folds)\n",
"for (V,μVT,ΣVT) in zip(folds,μ,Σ)\n",
" plt.plot(x[V], μVT)\n",
" # ±1 standard-deviation band from the marginal variances\n",
" err = sqrt.(diag(ΣVT))\n",
" plt.fill_between(x[V], μVT.-err, μVT.+err, alpha=0.3)\n",
"end\n",
"# true regression function for reference\n",
"xx = range(-2,stop=2,length=100)\n",
"plt.plot(xx, f_star.(xx), color=\"black\")\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-125.87309965178507"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# K-fold counterpart of logp_LOO: cross-validation log predictive density of gp for `folds`.\n",
"GaussianProcesses.logp_CVfold(gp, folds)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check dlogpdθ_CVfold (kernel hyperparameters only) against a finite-difference gradient of logp_CVfold.\n",
"function target(θ)\n",
"    θprev = get_params(gp.kernel)\n",
"    try\n",
"        set_params!(gp.kernel, θ)\n",
"        update_mll!(gp)\n",
"        return GaussianProcesses.logp_CVfold(gp, folds)\n",
"    finally\n",
"        set_params!(gp.kernel, θprev) # put it back, even if the evaluation throws\n",
"    end\n",
"end\n",
"# differentiate at gp's own kernel parameters so the check does not depend on the global `k`\n",
"# still referring to gp's kernel\n",
"grad_numerical = Calculus.gradient(target, get_params(gp.kernel))\n",
"update_mll!(gp) # target restores the parameters but not gp's cached state; refresh it here\n",
"grad_analytical = GaussianProcesses.dlogpdθ_CVfold(gp, folds; noise=false, kern=true, domean=false)\n",
"all(isapprox.(grad_numerical,grad_analytical, atol=1e-6))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3-element Array{Float64,1}:\n",
" -0.22892239541138876\n",
" -0.08158749657391891\n",
" 1.5412099504509285"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Current noise + kernel hyperparameters (1 noise + 2 SEIso entries; the zero mean has none).\n",
"# They equal optim_params from the initial fit, since the gradient checks restore them.\n",
"get_params(gp; noise=true, kern=true, domean=false)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same gradient check as above, now w.r.t. the noise and kernel hyperparameters jointly.\n",
"function target(θ)\n",
"    θprev = get_params(gp; noise=true, kern=true, domean=false)\n",
"    try\n",
"        set_params!(gp, θ; noise=true, kern=true, domean=false)\n",
"        update_mll!(gp; noise=true, kern=true, domean=false)\n",
"        return GaussianProcesses.logp_CVfold(gp, folds)\n",
"    finally\n",
"        set_params!(gp, θprev; noise=true, kern=true, domean=false) # put it back, even if the evaluation throws\n",
"    end\n",
"end\n",
"grad_numerical = Calculus.gradient(target, get_params(gp; noise=true, kern=true, domean=false))\n",
"update_mll!(gp) # target restores the parameters but not gp's cached state; refresh it here\n",
"grad_analytical = GaussianProcesses.dlogpdθ_CVfold(gp, folds; noise=true, kern=true, domean=false)\n",
"all(isapprox.(grad_numerical,grad_analytical, atol=1e-6))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Larger synthetic problem (3000 points) used to time the gradient computations below.\n",
"# Separate *_big names keep `n`, `x`, `Y` and `k` of the 100-point example intact, so the\n",
"# earlier cells can still be re-run after this one.\n",
"using Random\n",
"Random.seed!(1)\n",
"n_big = 3000\n",
"f_star(x::Real) = abs(x-5)*cos(2*x)\n",
"σ_y = 0.8\n",
"X_distr = Uniform(-2,2)\n",
"ϵ_distr = Normal(0,σ_y)\n",
"x_big = sort(rand(X_distr, n_big))\n",
"Y_big = f_star.(x_big) .+ rand(ϵ_distr, n_big)\n",
"k_big = SEIso(0.5, 0.8)\n",
"logNoise = log(σ_y)\n",
"gp_big = GP(Matrix(x_big'), Y_big, MeanZero(), k_big, logNoise)\n",
"# fit gp_big's own hyperparameters (not the small, already fitted `gp`) before timing it\n",
"optimize!(gp_big; domean=false, kern=true, noise=true)\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0.782424 seconds (28.80 k allocations: 1.592 MiB)\n"
]
}
],
"source": [
"# Workspace passed to update_mll_and_dmll! (presumably preallocated buffers; note the small\n",
"# allocation count reported for n = 3000).\n",
"precomp = GaussianProcesses.init_precompute(gp_big)\n",
"# Baseline timing: one update_mll_and_dmll! call on gp_big with only the kernel parameters active.\n",
"@time GaussianProcesses.update_mll_and_dmll!(gp_big, precomp; domean=false, kern=true, noise=false)\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 3.518531 seconds (217 allocations: 561.574 MiB, 5.84% gc time)\n"
]
}
],
"source": [
"# Four contiguous folds covering all 3000 points of gp_big.\n",
"folds_big = [1:500, 501:1500, 1501:2200, 2201:3000]\n",
"# untimed warm-up call, so that @time below does not include compilation\n",
"GaussianProcesses.dlogpdθ_CVfold(gp_big, folds_big; noise=true, domean=true, kern=true)\n",
"@time GaussianProcesses.dlogpdθ_CVfold(gp_big, folds_big; noise=true, domean=true, kern=true)\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"3-element Array{Float64,1}:\n",
" -7.203671608482038\n",
" -4070.9279980851143\n",
" 866.1506727441109"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collect profiler samples for one CV-fold gradient evaluation. This cell only displays the\n",
"# returned gradient; inspect the collected samples afterwards with Profile.print().\n",
"Profile.clear()\n",
"@profile GaussianProcesses.dlogpdθ_CVfold(gp_big, folds_big; noise=true, domean=true, kern=true)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimization"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Main.GPCV"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"module GPCV\n",
"    using GaussianProcesses: GaussianProcesses, get_params_kwargs, get_params, set_params!,\n",
"                             GPBase, LBFGS, bounds, Optim, optimize, OnceDifferentiable,\n",
"                             LinearAlgebra,\n",
"                             update_target!,\n",
"                             dlogpdθ_LOO, logp_LOO\n",
"\n",
"    \"\"\"\n",
"        optimize!(gp::GPBase, method = LBFGS(), args...; kwargs...)\n",
"\n",
"    Optimise the hyperparameters of Gaussian process `gp` by maximising its leave-one-out\n",
"    cross-validation log predictive density (`logp_LOO`), using the analytic gradient\n",
"    `dlogpdθ_LOO`. The gradient-based optimisation itself is delegated to the Optim package,\n",
"    to which the user is referred for further details. On return `gp` holds the optimised\n",
"    hyperparameters and `update_target!(gp)` has been called at them.\n",
"\n",
"    # Keyword arguments:\n",
"    * `domean::Bool`: Mean function hyperparameters should be optimized\n",
"    * `kern::Bool`: Kernel function hyperparameters should be optimized\n",
"    * `noise::Bool`: Observation noise hyperparameter should be optimized (GPE only)\n",
"    * `lik::Bool`: Likelihood hyperparameters should be optimized (GPA only)\n",
"    * `meanbounds`: [lowerbounds, upperbounds] for the mean hyperparameters\n",
"    * `kernbounds`: [lowerbounds, upperbounds] for the kernel hyperparameters\n",
"    * `noisebounds`: [lowerbound, upperbound] for the noise hyperparameter\n",
"    * `likbounds`: [lowerbounds, upperbounds] for the likelihood hyperparameters\n",
"    * `args/kwargs`: Arguments and keyword arguments for the optimize function from the Optim\n",
"      package https://julianlsolvers.github.io/Optim.jl/stable/#user/config/. If any bounds are\n",
"      given, the problem is solved with `Optim.Fminbox(method)` and only `args` are forwarded\n",
"      (pass an `Optim.Options` positionally to configure that case).\n",
"\n",
"    # Returns\n",
"    The results object returned by `Optim.optimize`.\n",
"    \"\"\"\n",
"    function optimize!(gp::GPBase, method = LBFGS(), args...; domean::Bool = true, kern::Bool = true,\n",
"                       noise::Bool = true, lik::Bool = true,\n",
"                       meanbounds = nothing, kernbounds = nothing,\n",
"                       noisebounds = nothing, likbounds = nothing, kwargs...)\n",
"        params_kwargs = get_params_kwargs(gp; domean=domean, kern=kern, noise=noise, lik=lik)\n",
"        func = get_optim_target(gp; params_kwargs...)\n",
"        init = get_params(gp; params_kwargs...) # Initial hyperparameter values\n",
"        if all(isnothing, (meanbounds, kernbounds, noisebounds, likbounds))\n",
"            results = optimize(func, init, method, args...; kwargs...) # Run optimizer\n",
"        else\n",
"            lb, ub = bounds(gp, noisebounds, meanbounds, kernbounds, likbounds;\n",
"                            domean = domean, kern = kern, noise = noise, lik = lik)\n",
"            # Fminbox is not among the names imported into this module, so qualify it via Optim\n",
"            results = optimize(func.f, func.df, lb, ub, init, Optim.Fminbox(method), args...)\n",
"        end\n",
"        set_params!(gp, Optim.minimizer(results); params_kwargs...)\n",
"        update_target!(gp)\n",
"        return results\n",
"    end\n",
"\n",
"    # True when `err`, raised while evaluating the target at `hyp`, marks an infeasible point that\n",
"    # the optimiser should see as Inf: non-finite hyperparameters, an ArgumentError, or a\n",
"    # covariance matrix that is not positive definite.\n",
"    _is_recoverable(err, hyp) = !all(isfinite, hyp) || err isa ArgumentError ||\n",
"                                err isa LinearAlgebra.PosDefException\n",
"\n",
"    # Build the Optim objective for `gp`: the negated LOO log predictive density and its\n",
"    # gradient, over the hyperparameters selected by `params_kwargs`.\n",
"    function get_optim_target(gp::GPBase; params_kwargs...)\n",
"        # Set gp's hyperparameters to `hyp`, refresh it and return `compute()`. On a recoverable\n",
"        # failure, restore the previous hyperparameters and return Inf; otherwise rethrow with\n",
"        # the original backtrace.\n",
"        function guarded(compute, hyp::AbstractVector)\n",
"            prev = get_params(gp; params_kwargs...)\n",
"            try\n",
"                set_params!(gp, hyp; params_kwargs...)\n",
"                update_target!(gp)\n",
"                return compute()\n",
"            catch err\n",
"                # reset parameters to remove any NaNs\n",
"                set_params!(gp, prev; params_kwargs...)\n",
"                _is_recoverable(err, hyp) || rethrow()\n",
"                println(err)\n",
"                return Inf\n",
"            end\n",
"        end\n",
"\n",
"        # objective value only\n",
"        ltarget(hyp::AbstractVector) = guarded(() -> -logp_LOO(gp), hyp)\n",
"\n",
"        # objective value, writing the gradient into `grad`\n",
"        function ltarget_and_dltarget!(grad::AbstractVector, hyp::AbstractVector)\n",
"            return guarded(hyp) do\n",
"                grad .= .-dlogpdθ_LOO(gp; params_kwargs...)\n",
"                return -logp_LOO(gp)\n",
"            end\n",
"        end\n",
"\n",
"        # gradient only, written into `grad`\n",
"        function dltarget!(grad::AbstractVector, hyp::AbstractVector)\n",
"            return guarded(hyp) do\n",
"                grad .= .-dlogpdθ_LOO(gp; params_kwargs...)\n",
"            end\n",
"        end\n",
"\n",
"        xinit = get_params(gp; params_kwargs...)\n",
"        func = OnceDifferentiable(ltarget, dltarget!, ltarget_and_dltarget!, xinit)\n",
"        return func\n",
"    end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3-element Array{Float64,1}:\n",
" -0.2305426001964157\n",
" 0.14212460173583263\n",
" 1.790568654194302"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Refit gp's noise and kernel hyperparameters with GPCV's leave-one-out CV objective, starting\n",
"# from its current values. This mutates gp: later cells see the CV-fitted hyperparameters.\n",
"GPCV.optimize!(gp; domean=false, kern=true, noise=true)\n",
"optimCV_params = get_params(gp; domean=false, kern=true, noise=true)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"optim_params = [-0.22892239541138876, -0.08158749657391891, 1.5412099504509285]\n",
"optimCV_params = [-0.2305426001964157, 0.14212460173583263, 1.790568654194302]\n"
]
}
],
"source": [
"# Compare the hyperparameters from the initial optimize! fit with the LOO-CV fit.\n",
"@show optim_params\n",
"@show optimCV_params\n",
";"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Julia 1.4.0",
"language": "julia",
"name": "julia-1.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.4.0"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ElasticPDMats = "2904ab23-551e-5aed-883f-487f97af5226"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaussianProcesses = "891a1506-143c-57d2-908e-e1f8e92e6de9"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Traceur = "37b6cedf-1f77-55f8-9503-c64b63398394"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment