-
-
Save Vaibhavdixit02/3e1cd7754622556a6f768ad099c877a3 to your computer and use it in GitHub Desktop.
First attempt at https://github.com/JuliaDiffEq/DiffEqBayes.jl/issues/16
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# DiffEq tools | |
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools | |
# clone from github.com/tpapp | |
using DynamicHMC, MCMCDiagnostics, DiffWrappers, ContinuousTransformations | |
using Parameters, Distributions, Optim | |
# Lotka–Volterra predator–prey model. In this (old) ParameterizedFunctions
# syntax, `a=>1.5` presumably marks `a` as the free parameter while `b`, `c`
# and `d` are fixed constants — NOTE(review): confirm against the package
# docs; the posterior below only ever varies `a`.
f1 = @ode_def_nohes LotkaVolterraTest1 begin
    dx = a*x - b*x*y
    dy = -c*y + d*x*y
end a=>1.5 b=1.0 c=3.0 d=1.0
u0, tspan = [1.0, 1.0], (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan)

# Synthetic observations at t = 1, 2, …, 10: the noise-free solution plus
# independent Normal(0, σ) noise on each of the two state components.
# No RNG seed is set, so the generated data differ from run to run.
σ = 0.01                          # observation noise sd, fixed for now
t = collect(linspace(1, 10, 10))  # observation times
sol = solve(prob1, Tsit5())
randomized = VectorOfArray([sol(tᵢ) + σ * randn(2) for tᵢ in t])
data = convert(Array, randomized) # column i holds the observation at t[i]
# Unnormalized posterior for the single free Lotka–Volterra parameter `a`.
# Instances are callable: `P(θ)` returns the log density at `a = θ[1]`.
#
# Fields
# - `problem`: ODE problem that is re-parameterized with each candidate `a`
# - `data`:    observations, one column per entry of `t`
# - `a_prior`: prior distribution of `a`
# - `t`:       observation times
# - `ϵ_dist`:  distribution of the per-component observation error
#
# Each field's type is its own type parameter, so every field is concretely
# typed for whatever values the constructor receives.
struct LotkaVolterraPosterior{TProblem, TData, TPrior, TTimes, TError}
    problem::TProblem
    data::TData
    a_prior::TPrior
    t::TTimes
    ϵ_dist::TError
end
# Log posterior density (up to an additive constant) of the Lotka–Volterra
# parameter `a = θ[1]`: re-parameterize `problem` with `a`, solve it with
# `Tsit5()`, sum `logpdf(ϵ_dist, ·)` over the residuals
# `sol(t[i]) .- data[:, i]` for every observation time, and add
# `logpdf(a_prior, a)`.
#
# A failure while building or solving the problem, or a NaN / +Inf
# log-likelihood, is scored as -Inf so the sampler rejects the point instead
# of aborting. InterruptException is re-raised so a long run can still be
# stopped with Ctrl-C.
function (P::LotkaVolterraPosterior)(θ)
    @unpack problem, data, a_prior, t, ϵ_dist = P
    a = θ[1]
    # Bind ℓ to the value of the whole try/catch expression: a variable first
    # assigned inside a `try` or `catch` block is local to that block and
    # would be undefined (UndefVarError) after `end`.
    ℓ = try
        prob = problem_new_parameters(problem, a)
        sol = solve(prob, Tsit5())
        # `tᵢ` keeps the vector of observation times `t` unshadowed.
        sum(sum(logpdf.(ϵ_dist, sol(tᵢ) .- data[:, i]))
            for (i, tᵢ) in enumerate(t))
    catch err
        isa(err, InterruptException) && rethrow()
        -Inf
    end
    # NaN or +Inf must not be accepted as a density; -Inf means "impossible".
    if !isfinite(ℓ) && ℓ ≠ -Inf
        ℓ = -Inf
    end
    return ℓ + logpdf(a_prior, a)
end
# Posterior for `a`: Normal(1.5, 1) prior, Normal(0, σ) observation error.
P = LotkaVolterraPosterior(prob1, data, Normal(1.5, 1), t, Normal(0.0, σ))

# Map the unconstrained sampler coordinate onto a > 0. The argument must be
# a 1-tuple of transformations: the comma goes after the call, not inside it.
parameter_transformation = TransformationTuple((bridge(ℝ, ℝ⁺),))
PT = TransformLogLikelihood(P, parameter_transformation)
PTG = ForwardGradientWrapper(PT, zeros(1))

# Step-size adaptation copes poorly with a bad starting point, so start the
# chain at the posterior mode of `a`, found by a bounded univariate search
# over [0, 10] on the negated log density.
a₀ = Optim.minimizer(optimize(a -> -P([a]), 0.0, 10.0))

# 1000: presumably the number of draws requested — NOTE(review): confirm
# against DynamicHMC's `NUTS_init_tune_mcmc` signature.
sample, _ = NUTS_init_tune_mcmc(PTG,
                                inverse(parameter_transformation, (a₀,)),
                                1000)

# Map each draw's position back to the constrained scale; the first element
# of the result holds the draws of `a`.
posterior = ungrouping_map(Vector, get_transformation(PT) ∘ get_position, sample)
a, = posterior
mean(a) # 1.49983
std(a)  # 0.0002
# Visualize the log posterior density of `a` on a 1000-point grid over [0, 2].
using Plots
pgfplots()
a_grid = linspace(0, 2, 1000)
ℓ_grid = [P([aᵢ]) for aᵢ in a_grid]
plot(a_grid, ℓ_grid; xlab = "a", ylab = "log density", label = false)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.