Skip to content

Instantly share code, notes, and snippets.

@alfredjmduncan
Created January 9, 2021 17:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alfredjmduncan/099e1493118489be6b632fdfd8213fd3 to your computer and use it in GitHub Desktop.
Save alfredjmduncan/099e1493118489be6b632fdfd8213fd3 to your computer and use it in GitHub Desktop.
Turing example with implicit functions and forwarddiff
using NLsolve
using ChainRules
using ForwardDiff
using ComponentArrays
using Distributions
using Turing
using StatsPlots
# Γ holds all of the data in the model in a single stacked vector
# (using the ComponentArray type). Conceptually, Γ has the following structure:
# Γ
# ├── y      unknown variables
# └── Ω      data
#     ├── x  input data (state, known variables)
#     ├── θ  parameters
#     └── y  initial values for the solver
# In this example Ω is stored flat: the code below reads its entries directly
# as Γ.Ω.θ₁, Γ.Ω.θ₂, Γ.Ω.θ₃ (the Γ.Ω.θ part) and Ω.y₁, Ω.y₂ (the Γ.Ω.y part).
# The combination of `model_residuals!` and Γ.Ω.θ defines the model,
# Γ.Ω.y provides initial values for the solver, and
# Γ.Ω.x holds the known values in the model, which could include exogenous
# states or matched observables.
"""
    model_residuals!(F, Γ)

Write the two model residuals into `F`, in place.

`Γ` must expose the unknowns as `Γ.y.y₁`, `Γ.y.y₂` and the inputs as `Γ.Ω.θ₁`,
`Γ.Ω.θ₂`, `Γ.Ω.θ₃`. The residuals are

    F[1] = exp(y₁) - θ₂ * exp(y₂) + θ₁
    F[2] = y₁ - y₂ + θ₃

The value stored in `F[2]` is also the return value.
"""
function model_residuals!(F, Γ)
    unknowns, inputs = Γ.y, Γ.Ω
    F[1] = exp(unknowns.y₁) - inputs.θ₂ * exp(unknowns.y₂) + inputs.θ₁
    gap = unknowns.y₁ - unknowns.y₂ + inputs.θ₃
    F[2] = gap
    return gap
end
# Unknowns as a function of input data
"""
    ŷ(Ω::AbstractVector{T}) where T

Solve `model_residuals!` for the unknowns `y₁`, `y₂` given the inputs `Ω`.

`Ω` must provide `Ω.y₁` and `Ω.y₂`, which are used as the solver's starting point,
alongside whatever entries `model_residuals!` reads from `Γ.Ω`. The result is the
`zero` field of the `nlsolve` result, i.e. the solver's final iterate; convergence
is not checked here.
"""
function ŷ(Ω::AbstractVector{T}) where T
    # Starting point carries the element type of Ω.
    guess = ComponentVector{T}(y₁=Ω.y₁, y₂=Ω.y₂)
    # Residuals as a function of the unknowns only, with Ω held fixed.
    residuals!(F, y) = model_residuals!(F, ComponentArray(y=y, Ω=Ω))
    solution = nlsolve(residuals!, guess; autodiff=:forward)
    return solution.zero
end
# Forward rule (see ChainRules.jl)
"""
    frule((_, ΔΩ), ::typeof(ŷ), Ω::AbstractVector{<:Real})

Solve the model at `Ω` and return `(y, ∂y)`, where `y = ŷ(Ω)` and `∂y` is the
`length(y) × length(Ω)` matrix of sensitivities of the solution to the inputs,
obtained from the implicit function theorem:

    ∂y = -(∂F/∂y)⁻¹ (∂F/∂Ω)

with `F` the residuals written by `model_residuals!`.

NOTE(review): `ΔΩ` is accepted but not used; the full sensitivity matrix is
returned. A rule following the ChainRules `frule` convention would presumably
return the pushforward `∂y * ΔΩ` instead. Confirm against the intended caller.

NOTE(review): this definition uses the unqualified name `frule`. Whether it
extends the ChainRules function or creates a separate one depends on how the name
was brought into scope. Verify before relying on an AD package picking it up.
"""
function frule((_,ΔΩ),::typeof(ŷ),Ω::AbstractVector{<:Real})
    y = ŷ(Ω)
    ny, nΩ = length(y), length(Ω)
    # Stack the solution and the inputs in the layout `model_residuals!` reads
    # (`Γ.y.*` and `Γ.Ω.*`).
    Γ = ComponentArray(y=y, Ω=Ω)
    # Differentiate the residuals with respect to the whole stacked vector.
    # `model_residuals!` is passed directly: wrapping the stacked vector in another
    # `ComponentArray(y=..., Ω=Ω)` would nest it under `.y`, and `Γ.y.y₁` would no
    # longer resolve. The output buffer is sized and typed from the inputs.
    j = ForwardDiff.jacobian(model_residuals!, zeros(eltype(Γ), ny), Γ)
    # Columns follow the stacking order: first the unknowns, then the inputs.
    Jy = j[1:ny, 1:ny]
    JΩ = j[1:ny, ny+1:ny+nΩ]
    ∂y = -(Jy \ JΩ)
    return y, ∂y
end
# Turing model for estimation
# `data` holds one observation per column; `θ₁[i]` is the known input that goes with
# column `i`. θ₂, θ₃ and s are the quantities being estimated.
@model function gtest(data,θ₁)
    # Priors on the model parameters
    θ₂ ~ Normal(3.0, 1.0)
    θ₃ ~ Normal(2.0, 0.6)
    # Prior on the measurement error; it enters the likelihood as sqrt(s)
    s ~ InverseGamma(2, 0.5)
    σ = sqrt(s)
    for i in 1:size(data, 2)
        # Inputs for observation i; y₁ and y₂ are the solver's starting values
        Ωᵢ = ComponentArray(θ₁=θ₁[i], θ₂=θ₂, θ₃=θ₃, y₁=0.0, y₂=1.0)
        data[:, i] ~ MvNormal(ŷ(Ωᵢ), σ)
    end
end
# Generate test data
# One θ₁ draw per observation; θ₂ and θ₃ are fixed at the values used to simulate
θ₁ = rand(Normal(7.0, 2.0), 100)
θ₂, θ₃ = 2.5, 1.0
# Column i of `data` is the model solution at θ₁[i] plus Normal(0, 0.1) noise
data = Matrix{Float64}(undef, 2, 100)
for (i, θ₁ᵢ) in enumerate(θ₁)
    Ωᵢ = ComponentArray(θ₁=θ₁ᵢ, θ₂=θ₂, θ₃=θ₃, y₁=0.0, y₂=1.0)
    data[:, i] = ŷ(Ωᵢ) + rand(Normal(0, 0.1), 2)
end
# Estimation: NUTS with 100 adaptation steps and target acceptance 0.65, 1000 draws
c = sample(gtest(data, θ₁), NUTS(100, 0.65), 1000)
# Summarise and plot results
describe(c)
plot(c)
savefig("gtest-plot.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment