Laplace approximation for effortless Bayesian deep learning - logistic regression.
# Import libraries.
using Flux, Plots, Random, PlotThemes, Statistics, BayesLaplace
theme(:wong)
# Toy data:
xs, y = toy_data_linear(100)
X = hcat(xs...); # stack the observations column-wise into a 2×100 matrix
data = zip(xs, y)
# Neural network:
nn = Chain(Dense(2,1))
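# With a single Dense(2,1) layer and the logistic loss below, the model is
# just a logistic regression classifier.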
λ = 0.5
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()
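# Note: the L2 penalty above corresponds to a zero-mean Gaussian prior on the
# weights with precision λ²; re-using the same λ for the Laplace approximation
# below keeps the prior consistent with the training objective.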
# Training:
using Flux.Optimise: update!, ADAM
opt = ADAM()
epochs = 50
for epoch in 1:epochs
    for d in data
        gs = gradient(params(nn)) do
            loss(d...)
        end
        update!(opt, params(nn), gs)
    end
end
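# Optional sanity check (a sketch, not part of the original gist): report the
# average training loss after the final epoch, using `mean` from Statistics.
println("Avg. training loss: ", mean(d -> loss(d...), data))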
# Laplace approximation:
la = laplace(nn, λ=λ)
fit!(la, data)
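# Fitting the Laplace approximation amounts to evaluating the Hessian of the
# penalized loss at the trained weights, which defines a Gaussian
# approximation to the posterior centred at the MAP estimate.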
p_plugin = plot_contour(X', y, la; title="Plugin", type=:plugin);
p_laplace = plot_contour(X', y, la; title="Laplace");
# Compare the plugin estimate with the Laplace posterior predictive side by side.
plot(p_plugin, p_laplace, layout=(1,2), size=(1000,400))
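# To get posterior predictive probabilities for inputs directly, something
# along these lines should work (a sketch; assumes BayesLaplace exports a
# `predict(la, X)` method taking the fitted approximation and an input matrix):
p_pred = predict(la, X)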