Skip to content

Instantly share code, notes, and snippets.

@wupeifan
Last active June 26, 2020 04:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wupeifan/1f552c8122382286dae69a3c20dddca3 to your computer and use it in GitHub Desktop.
Save wupeifan/1f552c8122382286dae69a3c20dddca3 to your computer and use it in GitHub Desktop.
Example of DifferentiableStateSpaceModels (DSSM) + Zygote + Turing
using DifferentiableStateSpaceModels, ModelingToolkit, SparseArrays, LinearAlgebra, Parameters, Test, TimerOutputs, BenchmarkTools
using Turing, Zygote, ChainRules
Turing.setadbackend(:zygote)

# --- Simulate synthetic observations from the RBC example model ---
H, mod_vals = Examples.rbc()
model = FirstOrderStateSpaceModel(H; mod_vals...)
p = [0.2, 0.02, 0.01]
θ_value = [0.5, 0.95]  # parameter values used to generate the data
solution, _ = solve_model!(model; θ = θ_value, p = p)

T = 200  # number of simulated periods
eps_value = 0.01 .* randn(1, T)  # shock path, 1 × T
seq = solve_sequence(model, solution, eps_value)

# Observation matrix: maps each 4-element state vector seq.u[t] to 2 observables.
Q = zeros(2, 4)
Q[1, 1] = 1  # k
Q[2, 3] = 1  # c
obs = map(t -> Q * seq.u[t], 1:T)
"""
    seq_ϵ(ϵ)

Return the simulated state path `seq.u` for the shock matrix `ϵ`, holding the model
solution fixed at the global `solution` (i.e. a function of `ϵ` only).
"""
seq_ϵ(ϵ) = solve_sequence(model, solution, ϵ).u
# Custom reverse rule for `seq_ϵ`. Zygote uses it instead of differentiating through
# `solve_sequence`; the pullback is assembled from the sensitivities `seq.u_ϵ`.
#
# For a cotangent `Δu` with one entry per element of `seq.u`, the pullback returns
#     Δϵ = Σ_i Σ_j Δu[i][j] * seq.u_ϵ[i][j, :, :]
# Entries of `Δu` that are `nothing` (how Zygote represents a zero cotangent, e.g. for
# elements of `seq.u` the loss never touched) contribute zero and are skipped, instead
# of failing with a MethodError on `nothing[j]`.
function ChainRules.rrule(::typeof(seq_ϵ), ϵ)
    seq = solve_sequence(model, solution, ϵ)
    function seq_pullback(Δu)
        Δϵ = zeros(size(ϵ))
        for (i, Δu_i) in enumerate(Δu)
            Δu_i === nothing && continue
            u_ϵ_i = seq.u_ϵ[i]
            for j in eachindex(seq.u[i])
                # In-place accumulation through a view: no slice copy and no new Δϵ
                # array per iteration.
                Δϵ .+= view(u_ϵ_i, j, :, :) .* Δu_i[j]
            end
        end
        return ChainRules.NO_FIELDS, Δϵ
    end
    return seq.u, seq_pullback
end
# Turing model for the shock path `ϵ` (n_shocks × T) and the observation-noise scale `Ω`.
# The model solution is held fixed (via `seq_ϵ`). Each observation `obs[i]` is modeled as
# MvNormal(Q * u_i, Ω), where u_i is the i-th simulated state vector.
# Reads the globals `T` and `Q`. Tilde variable names/order are kept as-is because
# sampler specs elsewhere refer to them by name.
@model function estimate_ϵ(obs)
    n_shocks = 1
    Ω ~ InverseGamma(3, 0.02)
    ϵ ~ filldist(Normal(0, 0.01), n_shocks, T)
    u_path = seq_ϵ(ϵ)
    for i in 1:T
        obs[i] ~ MvNormal(Q * u_path[i], Ω)
    end
end
# Estimate ϵ (and Ω) only, with HMC(step size 0.05, 7 leapfrog steps).
n_samples = 1000
est_model = estimate_ϵ(obs)
chain = sample(est_model, HMC(0.05, 7), n_samples; progress = true, save_state = true)
## Both θ and ϵ
"""
    seq_θ_ϵ(θ, ϵ)

Re-solve the global `model` at parameters `θ` (with the global `p`), then return the
simulated state path `seq.u` for the shock matrix `ϵ`.

NOTE(review): `solve_model!` is called on the shared global `model`; the `!` suggests
it mutates/caches into it — confirm before calling concurrently.
"""
function seq_θ_ϵ(θ, ϵ)
    sol_θ, _ = solve_model!(model; θ = θ, p = p)
    return solve_sequence(model, sol_θ, ϵ).u
end
# Custom reverse rule for `seq_θ_ϵ`, built from the sensitivities returned by
# `solve_sequence` so Zygote does not differentiate through the solver.
#
# For a cotangent `Δu` with one entry per element of `seq.u`, the pullback returns
#     Δθ = Σ_i vec(Δu[i]' * seq.u_θ[i])
#     Δϵ = Σ_i Σ_j Δu[i][j] * seq.u_ϵ[i][j, :, :]
# Entries of `Δu` that are `nothing` (how Zygote represents a zero cotangent, e.g. for
# elements of `seq.u` the loss never touched) contribute zero and are skipped, instead
# of failing with a MethodError.
function ChainRules.rrule(::typeof(seq_θ_ϵ), θ, ϵ)
    solution, _ = solve_model!(model, θ = θ, p = p)
    seq = solve_sequence(model, solution, ϵ)
    function seq_pullback(Δu)
        Δθ = zeros(size(θ))
        Δϵ = zeros(size(ϵ))
        for (i, Δu_i) in enumerate(Δu)
            Δu_i === nothing && continue
            # Chain rule through ∂u_i/∂θ; accumulated in place rather than rebinding Δθ.
            Δθ .+= vec(Δu_i' * seq.u_θ[i])
            u_ϵ_i = seq.u_ϵ[i]
            for j in eachindex(seq.u[i])
                # View-based in-place accumulation: no slice copy, no new Δϵ per step.
                Δϵ .+= view(u_ϵ_i, j, :, :) .* Δu_i[j]
            end
        end
        return ChainRules.NO_FIELDS, Δθ, Δϵ
    end
    return seq.u, seq_pullback
end
# Turing model that jointly infers the parameters θ = [α, β], the observation-noise
# scale `Ω` and the shock path `ϵ` (n_shocks × T). Every evaluation re-solves the model
# at the sampled θ via `seq_θ_ϵ`; each observation `obs[i]` is MvNormal(Q * u_i, Ω).
# Reads the globals `T` and `Q`. Tilde variable names/order are kept as-is because the
# Gibbs sampler below refers to :α, :β, :Ω and :ϵ by name.
@model function estimate_θ_ϵ(obs)
    n_shocks = 1
    α ~ Uniform(0.4, 0.6)
    β ~ Uniform(0.8, 0.99)
    Ω ~ InverseGamma(3, 0.02)
    ϵ ~ filldist(Normal(0, 0.01), n_shocks, T)
    u_path = seq_θ_ϵ([α, β], ϵ)
    for i in 1:T
        obs[i] ~ MvNormal(Q * u_path[i], Ω)
    end
end
# Joint estimation of θ and ϵ.
n_samples = 1000
est_model = estimate_θ_ϵ(obs)

# (1) Plain HMC over all variables: step size 0.05, 8 leapfrog steps.
chain = sample(est_model, HMC(0.05, 8), n_samples; progress = true, save_state = true)

# (2) Gibbs: HMC for (α, β, Ω), particle Gibbs with 20 particles for the shock path ϵ.
mix_sampler = Gibbs(HMC(0.001, 7, :α, :β, :Ω), PG(20, :ϵ))
chain = sample(est_model, mix_sampler, n_samples; progress = true)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment