Skip to content

Instantly share code, notes, and snippets.

@slwu89
Last active December 12, 2022 20:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slwu89/136ffe03913883acace3e5378fafce89 to your computer and use it in GitHub Desktop.
Save slwu89/136ffe03913883acace3e5378fafce89 to your computer and use it in GitHub Desktop.
Simon Wood's urchin growth example in Julia: fitting a model with random effects using the Laplace approximation. Requires Julia >= 1.8 for typed globals (https://julialang.org/blog/2022/08/julia-1.8-highlights/#typed_globals)
using Distributions
using ForwardDiff
using Optim
using LineSearches # https://github.com/JuliaNLSolvers/Optim.jl/issues/713
# logdet
using LinearAlgebra
# get the data
using CSV
using HTTP
using DataFrames
# Download the urchin volume data. Column names are supplied explicitly via `header`, and
# parsing starts at line 2 (`skipto`), so line 1 of the file (presumably its own header —
# verify) is not read as data.
http_response = HTTP.get("https://www.maths.ed.ac.uk/~swood34/data/urchin-vol.txt")
uv = CSV.read(
    IOBuffer(http_response.body),
    DataFrame;
    header = ["row", "age", "vol"],
    skipto = 2,
)
# the biological model
"""
    urchin_volumes(log_ω, log_g, log_p, a)

Return the predicted volume for each age in `a` under a two-phase growth curve.

With `ω = exp(log_ω)`, `g = exp.(log_g)` and `p = exp.(log_p)`, define the switch age
`am = log(p / (g * ω)) / g`. Before `am`, the volume is `ω * exp(g * a)`. From `am` on, it
is `p / g + p * (a - am)`. `log_g`, `log_p` and `a` are read element-wise at the indices of
`a`, giving one prediction per element of `a`.
"""
function urchin_volumes(log_ω, log_g::AbstractArray, log_p::AbstractArray, a::AbstractArray)
    ω = exp(log_ω)
    g = exp.(log_g)
    p = exp.(log_p)
    # age at which the curve switches from exponential to linear growth
    am = @. log(p / (g * ω)) / g
    volume_at(k) = a[k] < am[k] ? ω * exp(g[k] * a[k]) : p[k] / g[k] + p[k] * (a[k] - am[k])
    return map(volume_at, eachindex(a))
end
# joint log-density of likelihood and random effects
"""
    urchin_jll(y, θ, log_g, log_p, a)

Return the joint log-density of the data `y` and the random effects `log_g`, `log_p`.

`θ` is unpacked as `(log_ω, μg, log_σg, μp, log_σp, log_σ)`. The three terms are:
- data: `sqrt.(y)` scored against `Normal(sqrt(μ), exp(log_σ))`, where `μ` comes from
  `urchin_volumes(log_ω, log_g, log_p, a)`;
- `log_g` scored against `Normal(μg, exp(log_σg))`;
- `log_p` scored against `Normal(μp, exp(log_σp))`.
"""
function urchin_jll(y::AbstractArray, θ::AbstractArray, log_g::AbstractArray, log_p::AbstractArray, a::AbstractArray)
    log_ω, μg, log_σg, μp, log_σp, log_σ = θ
    σ, σg, σp = exp(log_σ), exp(log_σg), exp(log_σp)
    predicted = urchin_volumes(log_ω, log_g, log_p, a)
    # observations enter on the square-root scale
    obs_term = sum(logpdf.(Normal.(sqrt.(predicted), σ), sqrt.(y)))
    # Gaussian densities of the two random effects
    g_term = loglikelihood(Normal(μg, σg), log_g)
    p_term = loglikelihood(Normal(μp, σp), log_p)
    return obs_term + g_term + p_term
end
# return Laplace-approximated log likelihood
# in the future, we'd want this as a callable struct, and to allow `b` to be sampled/set flexibly each
# iteration, including reusing the optimal values from previous rounds of optimization over θ
# following: https://github.com/ElOceanografo/MarginalLogDensities.jl/blob/master/src/MarginalLogDensities.jl#L229
"""
    urchin_laplace_ll(y, a, θ; b0 = [fill(θ[2], length(y)); fill(θ[4], length(y))])

Return the Laplace-approximated marginal log-likelihood of `θ`, integrating out the random
effects `b = [log_g; log_p]` of length `nb = 2n`, where `n = length(y)`.

The negative joint log-density `b -> -urchin_jll(y, θ, b[1:n], b[n+1:end], a)` is minimised
over `b` (LBFGS with backtracking line search and forward-mode AD). With `b̂` the minimiser
and `H` the Hessian of that negative log-density at `b̂`, the result is

    -min + (nb * log(2π) - logdet(H)) / 2

# Arguments
- `y`: observations, passed to `urchin_jll`.
- `a`: ages, passed to `urchin_jll`.
- `θ`: `[log_ω, μg, log_σg, μp, log_σp, log_σ]`.

# Keywords
- `b0`: starting point of the inner optimisation. It defaults to `μg` repeated `n` times
  followed by `μp` repeated `n` times, as in the original code. It must have length `2n`
  and is not modified.

# Throws
- `DimensionMismatch` if `length(b0) != 2 * length(y)`.
"""
function urchin_laplace_ll(y::AbstractArray, a::AbstractArray, θ::AbstractArray;
                           b0::AbstractVector = [fill(θ[2], length(y)); fill(θ[4], length(y))])
    n = length(y)
    nb = 2n
    length(b0) == nb ||
        throw(DimensionMismatch("b0 must have length $nb (2 * length(y)), got $(length(b0))"))
    # negative joint log-density as a function of the random effects only; views avoid
    # copying the two halves of `b` on every objective/Hessian evaluation
    f = b -> -urchin_jll(y, θ, view(b, 1:n), view(b, n+1:nb), a)
    mle = optimize(f, copy(b0), LBFGS(linesearch=LineSearches.BackTracking()); autodiff = :forward)
    H0 = ForwardDiff.hessian(f, mle.minimizer)
    # nb * log(2π) instead of log((2π)^nb): the power overflows to Inf once nb exceeds ~386
    return -mle.minimum + 0.5 * (nb * log(2π) - logdet(H0))
end
# testing with actual data
# ---- fit the model to the urchin data ----------------------------------------------------
# typed globals (Julia >= 1.8) give the closure below inferable types for `y` and `a`
y::Vector{Float64} = uv[:, :vol]
a::Vector{Int64} = uv[:, :age]

# starting values, assembled below as θ = [log_ω, μg, log_σg, μp, log_σp, log_σ]
log_ω = -4.0
μg = -0.2
log_σg = log(0.1)
μp = 0.2
log_σp = log(0.1)
log_σ = log(0.5)
θ0::Vector{Float64} = [log_ω, μg, log_σg, μp, log_σp, log_σ]

# outer objective over θ: negative Laplace-approximated log-likelihood
urchin_laplace_nll = (θ) -> -urchin_laplace_ll(y, a, θ)
# original (misspelled) name, kept so existing references keep working
urchin_laplce_nll = urchin_laplace_nll

# No `autodiff` keyword is passed, so Optim presumably uses its default finite-difference
# gradients here; every objective evaluation runs its own inner optimisation over the
# random effects.
fit = optimize(
    urchin_laplace_nll,
    θ0,
    LBFGS(linesearch=LineSearches.BackTracking()),
    Optim.Options(
        f_tol = 1e-6,
        g_tol = 1e-6,
        iterations = 30,
        show_trace = true,
    ),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment