Skip to content

Instantly share code, notes, and snippets.

@mschauer
Last active April 3, 2021 10:54
Show Gist options
  • Save mschauer/6841531d0370d46c1d9ad6e23feda489 to your computer and use it in GitHub Desktop.
Save mschauer/6841531d0370d46c1d9ad6e23feda489 to your computer and use it in GitHub Desktop.
Linear regression with p = 5 coefficients and n = 10_000_000 observations, sampled with subsampling and an approximate maximum-likelihood (ML) estimate as control variate
using ZigZagBoomerang
using StaticArrays
using LinearAlgebra
using SparseArrays
using Random
using Test
using Statistics
Random.seed!(2)
using StaticArrays
# Synthetic linear-regression data.
# Model sketch kept from the original notes:
#   scale ~ Exponential(λ)
#   coefs ~ Normal()
#   preds ~ Normal(dot(x, coefs), scale)
# NOTE(review): the code below adds plain `randn(n)` noise, i.e. unit noise scale;
# the two lines that would draw a random scale are left disabled:
#λ = 1.0
#σ = randexp()*λ
# True coefficient vector: five standard-normal draws held in a static vector.
β = @SVector randn(5)
# Number of observations.
n = 10_000_000
# Number of coefficients. NOTE(review): `d` is not referenced in the visible code.
const d = 5
# Covariates: a length-n vector whose elements are random vectors of the same type as β.
X = randn(typeof(β), n)
# Responses: linear predictor dot(X[i], β) plus standard-normal noise.
y = dot.(X, Ref(β)) + randn(n)
"""
    ∇ϕkhat(β, samples, X, y, μ, bias)

Return a subsampled estimate of the likelihood gradient at `β`, using `μ` as the
control-variate reference point.

Starting from `bias`, the function draws `samples` indices uniformly (with replacement)
from `1:length(y)`. For each drawn index `i` it adds the per-observation term
`-X[i] * (y[i] - dot(X[i], β))` and subtracts the same term evaluated at `μ`, both
scaled by `length(y) / samples`.

# Arguments
- `β`: point at which the gradient is estimated.
- `samples`: number of random observations to draw.
- `X`, `y`: covariate vectors and responses, indexed together.
- `μ`: reference point of the control variate.
- `bias`: starting value of the accumulator; the script passes the full-data gradient at `μ`.
"""
function ∇ϕkhat(β, samples, X, y, μ, bias)
    nobs = length(y)
    weight = nobs / samples  # loop-invariant scaling of each drawn term
    acc = bias
    for _ in 1:samples
        k = rand(1:nobs)
        xk = X[k]
        yk = y[k]
        acc += weight * (-xk * (yk - dot(xk, β)))  # likelihood
        acc -= weight * (-xk * (yk - dot(xk, μ)))  # control
    end
    return acc
end
"""
    prior(x)

Return `x` unchanged. This is the prior's contribution to the gradient that `∇ϕ!`
adds to the likelihood estimate (Gaussian prior).
"""
function prior(x)
    return x  # Gaussian prior
end
"""
    ∇ϕ!(x_, x::T, args...) where {T}

Return the gradient at `x`: the prior term `prior(x)` plus the subsampled likelihood
estimate `∇ϕkhat(x, args...)`.

The first argument `x_` is accepted but not used. The estimate is asserted to have the
same type `T` as `x`. `args...` is forwarded unchanged to `∇ϕkhat`.
"""
function ∇ϕ!(x_, x::T, args...) where {T}
    likelihood_part = ∇ϕkhat(x, args...)::T
    return prior(x) + likelihood_part
end
"""
    ∇ϕfull(μ, X, y)

Return the full-data sum of `-X[i] * (y[i] - dot(X[i], μ))` over all observations.

# Arguments
- `μ`: point at which the sum is evaluated.
- `X`, `y`: covariate vectors and responses; they must share the same indices.

# Throws
- `DimensionMismatch` if `X` and `y` do not have matching indices.
"""
function ∇ϕfull(μ, X, y)
    # eachindex(X, y) verifies that both containers share the same indices, so every
    # X[i] and y[i] access under @inbounds is known to be valid.
    return @inbounds sum(-X[i] * (y[i] - dot(X[i], μ)) for i in eachindex(X, y))
end
# Sampler setup and run.
# `c` is handed to `pdmp` below. NOTE(review): presumably the initial bound on the event
# rate that `adapt=true` may adjust; confirm against the ZigZagBoomerang docs.
c = 50000.0
# Look at a fraction of the data: the first 1/200 of the observations, with the vector of
# static vectors reinterpreted as a Float64 matrix and transposed to one row per observation.
X_ = reinterpret(reshape, Float64, X[1:end÷200])'
# Least-squares fit on that subsample; used below as the control-variate reference point
# and as the starting position.
μ = SVector{5,Float64}(X_\(y[1:end÷200]))
# One look at the full gradient: the full-data likelihood gradient at μ, used as the
# starting value (`bias`) of every subsampled estimate.
bias = ∇ϕfull(μ, X, y) # sum(-X[i]*(y[i] - dot(X[i], μ)) for i in eachindex(y))
# Start time, start position and a random start velocity.
t0 = 0.0
x0 = μ
θ0 = @SVector randn(Float64, 5)
# NOTE(review): this first Γ (Gram matrix of the subsample) is overwritten on the next
# line and never used; confirm which one is intended.
Γ = SMatrix{5, 5, Float64, 25}(((X_'*X_)))
Γ = SMatrix{5, 5, Float64, 25}(Diagonal(1e7*ones(5)))
# Time horizon of the main run, scaled by 1/sqrt(n).
T = 2000.0 * 1/sqrt(n)
# NOTE(review): the third argument T/100 is presumably the refreshment rate; confirm.
BP = BouncyParticle(Γ, μ, T/100)
# Number of observations drawn per gradient estimate (forwarded to ∇ϕkhat via ∇ϕ!).
samples = 20
# Short first run over T/20 from x0; only its final position xT is reused below.
# The fourth return value is discarded, so the second run starts from the original `c`.
trace, (tT, xT, θT), (acc, num), _ = pdmp(∇ϕ!, t0, x0, θ0, T/20, c, BP, samples, X, y, μ, bias, adapt=true)
# Timed main run over T, started at xT.
# NOTE(review): it restarts at time t0 with velocity θ0 rather than tT/θT; confirm intended.
@time trace, (tT, xT, θT), (acc, num), c = pdmp(∇ϕ!, t0, xT, θ0, T, c, BP, samples, X, y, μ, bias, adapt=true)
# Split the trace into two sequences; presumably event times `ts` and positions `xs`.
ts, xs = ZigZagBoomerang.sep(trace)
# Last component of each item of the trace discretized with step 1/sqrt(n);
# presumably the positions on a regular time grid (confirm).
xsd = last.(collect(discretize(trace, 1/sqrt(n))))
# Plot the first two coordinates of the trajectory together with reference points.
using Makie
# Trajectory in the (coordinate 1, coordinate 2) plane, coloured by `ts`.
p1 = lines(getindex.(xs,1), getindex.(xs,2), linewidth=0.4, color=ts)
# Points of the discretized trace.
scatter!(getindex.(xsd,1), getindex.(xsd,2))
scatter!([β[1]], [β[2]], color=:red) # truth
scatter!([μ[1]], [μ[2]], color=:blue) # approximate mode
# Least-squares solution on the full data set.
# NOTE(review): labelled "posterior mean" below, but a plain least-squares fit does not
# include the prior term that ∇ϕ! adds; presumably negligible at this n; confirm.
m = reinterpret(reshape, Float64, X)'\y
scatter!([m[1]], [m[2]], color=:orange) # analytic posterior mean
# Evaluate to the figure so it is displayed.
p1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment