Created
January 9, 2021 17:12
-
-
Save alfredjmduncan/099e1493118489be6b632fdfd8213fd3 to your computer and use it in GitHub Desktop.
Turing example with implicit functions and forwarddiff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Random

using NLsolve
using ChainRules
using ForwardDiff
using ComponentArrays
using Distributions
using Turing
using StatsPlots
"""
    model_residuals!(F, Γ)

Write the residuals of the two-equation model into `F` and return `F`:

    F[1] = exp(y₁) - θ₂ * exp(y₂) + θ₁
    F[2] = y₁ - y₂ + θ₃

The unknowns `y` solve the model when both residuals are zero. Together with the
parameters in `Γ.Ω`, this function defines the model.

`Γ` stacks all of the data in the model (a `ComponentArray` elsewhere in this file),
conceptually:

    Γ
    ├── y   unknown variables (y₁, y₂)
    └── Ω   data
        ├── x   input data (state, known variables, e.g. exogenous states or
        │       matched observables); not read by the current model
        ├── θ   parameters θ₁, θ₂, θ₃
        └── y   initial values y₁, y₂ for the solver

Only `Γ.y.y₁`, `Γ.y.y₂`, `Γ.Ω.θ₁`, `Γ.Ω.θ₂` and `Γ.Ω.θ₃` are read here, so any
container exposing those properties works. In this file the parameters and the solver's
initial values are stored directly in `Ω` (`Ω.θ₁`, `Ω.y₁`, …) rather than in nested
`θ`/`y` sub-blocks.
"""
function model_residuals!(F, Γ)
    y, Ω = Γ.y, Γ.Ω
    F[1] = exp(y.y₁) - Ω.θ₂ * exp(y.y₂) + Ω.θ₁
    F[2] = y.y₁ - y.y₂ + Ω.θ₃
    # Return the mutated buffer (Julia convention for `!` functions) instead of,
    # implicitly, the value of the last assignment.
    return F
end
"""
    ŷ(Ω)

Solve the model for its unknowns as a function of the input data `Ω`.

`Ω` must expose the parameters read by `model_residuals!` (`θ₁`, `θ₂`, `θ₃`) and the
solver's starting values `y₁`, `y₂` as properties (a `ComponentArray` in this file).
The residual system is solved with `nlsolve` using forward-mode AD for its Jacobian,
starting from a `ComponentVector{T}` built from `Ω.y₁` and `Ω.y₂`. The `zero` field of
the result is returned.

NOTE(review): the solver's convergence flag is not checked, so a failed solve returns
whatever point `nlsolve` stopped at.
"""
function ŷ(Ω::AbstractVector{T}) where {T}
    # Residuals as a function of the unknowns only, with the data Ω held fixed.
    residuals!(F, yguess) = model_residuals!(F, ComponentArray(y = yguess, Ω = Ω))
    start = ComponentVector{T}(y₁ = Ω.y₁, y₂ = Ω.y₂)
    solution = nlsolve(residuals!, start; autodiff = :forward)
    return solution.zero
end
# Forward-mode rule (pushforward) for ŷ, registered with ChainRules.
#
# By the implicit function theorem, at a solution F(y, Ω) = 0 we have
#     ∂y/∂Ω = -(∂F/∂y) \ (∂F/∂Ω),
# so a tangent ΔΩ of the input maps to the tangent (∂y/∂Ω) * ΔΩ of the output.
# No derivative is taken through the nlsolve iterations.
#
# NOTE(review): ForwardDiff.jl does not consume ChainRules rules by itself, so with a
# ForwardDiff AD backend the Dual numbers still flow through the nlsolve iterations in ŷ.
# This rule is only picked up by ChainRules-aware AD (or through a bridge package such as
# ForwardDiffChainRules.jl). Confirm against the AD backend actually in use.
function ChainRules.frule((_, ΔΩ), ::typeof(ŷ), Ω::AbstractVector{<:Real})
    y = ŷ(Ω)
    ny = length(y)
    # Stack unknowns and data exactly as model_residuals! expects. The Jacobian's first ny
    # columns are then ∂F/∂y and the remaining columns are ∂F/∂Ω. The system is square
    # (one residual per unknown), so the residual buffer has length ny.
    Γ = ComponentArray(y = y, Ω = Ω)
    J = ForwardDiff.jacobian(model_residuals!, zeros(ny), Γ)
    ∂y∂Ω = -(J[:, 1:ny] \ J[:, (ny + 1):end])
    return y, ∂y∂Ω * ΔΩ
end
# Turing model for estimation.
#
# Each column of `data` is one noisy observation of the two unknowns returned by ŷ, and
# `θ₁[i]` is the known θ₁ for column i. The parameters θ₂ and θ₃ and the
# measurement-error variance `s` are estimated.
@model function gtest(data, θ₁)
    # Priors on the model parameters.
    θ₂ ~ Normal(3.0, 1.0)
    θ₃ ~ Normal(2.0, 0.6)
    # Prior on the measurement-error variance. Its square root is the scalar noise scale
    # passed to MvNormal below.
    s ~ InverseGamma(2, 0.5)
    for i in axes(data, 2)
        # Every observation uses the same solver starting point (y₁ = 0, y₂ = 1).
        Ω = ComponentArray(θ₁ = θ₁[i], θ₂ = θ₂, θ₃ = θ₃, y₁ = 0.0, y₂ = 1.0)
        μ = ŷ(Ω)
        # NOTE(review): `MvNormal(μ, σ::Real)` is the older Distributions.jl isotropic
        # form. Newer releases deprecate it in favour of `MvNormal(μ, σ^2 * I)`, so
        # confirm the Distributions version in use.
        data[:, i] ~ MvNormal(μ, sqrt(s))
    end
end
# Generate test data: `n` solutions of the model at known "true" parameters, each
# perturbed by independent Normal(0, 0.1) measurement noise.
# Fixed seed so the generated data are reproducible. Sampling below should also be
# reproducible if it draws from the global RNG (confirm).
Random.seed!(1234)
n = 100
# Draw θ₁ from Normal(7, 2) truncated to (0, ∞). Solving the model gives
# exp(y₂) = θ₁ / (θ₂ - exp(-θ₃)). With θ₂ = 2.5 and θ₃ = 1.0 the denominator is positive,
# so a real solution exists only when θ₁ > 0. An untruncated draw can occasionally be
# ≤ 0, and ŷ (which does not check convergence) would then return an unconverged point
# as "data".
θ₁ = rand(truncated(Normal(7.0, 2.0), 0.0, Inf), n)
θ₂, θ₃ = 2.5, 1.0
data = Matrix{Float64}(undef, 2, n)
for i in 1:n
    Ω = ComponentArray(θ₁ = θ₁[i], θ₂ = θ₂, θ₃ = θ₃, y₁ = 0.0, y₂ = 1.0)
    data[:, i] = ŷ(Ω) + rand(Normal(0, 0.1), 2)
end
# Estimation: NUTS with 100 adaptation steps and a target acceptance rate of 0.65
# (Turing's `NUTS(n_adapts, δ)` form), drawing 1000 samples.
c = sample(gtest(data, θ₁), NUTS(100, 0.65), 1000)
# Summarise and plot results.
describe(c)
savefig(plot(c), "gtest-plot.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment