Created August 13, 2019 08:33
-
-
Save tpapp/569bdef475c0a44c46d5c9e6d5e9e493 to your computer and use it in GitHub Desktop.
Minimal working example (MWE) for a DiffEq automatic-differentiation (AD) mismatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#####
##### Self-contained example to dissect broken tests in
##### https://github.com/JuliaDiffEq/DiffEqBayes.jl/blob/e033307768892e2f7242ae0aab3e09ec4819c11b/test/dynamicHMC.jl
#####
##### NOTE: this requires DynamicHMC#master and LogDensityProblems#0.9.x.
####
#### Part that relies only on the DiffEq ecosystem
####
using OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools, Parameters, | |
Distributions, Random | |
"""
    DynamicHMCPosterior(alg, problem, likelihood, priors, kwargs)

Callable container for a log posterior over the parameters of a DiffEq problem.

Fields, in constructor order:

- `alg`: solver algorithm handed to `solve`.
- `problem`: problem that is `remake`d with the candidate parameters on each call.
- `likelihood`: function mapping a solution to its log likelihood.
- `priors`: prior distributions, paired element-wise with the parameters.
- `kwargs`: keyword arguments forwarded to `solve`.
"""
struct DynamicHMCPosterior{Alg,Prob,Lik,Pri,Kw}
    alg::Alg
    problem::Prob
    likelihood::Lik
    priors::Pri
    kwargs::Kw
end
# Evaluate the log posterior at the parameters `θ`.
#
# `θ` must have a property `a` (e.g. the named tuple `(a = 1.5,)`). `a` becomes the
# problem parameter `p`, and `u0` is converted to `eltype(a)` so that AD number types
# (e.g. ForwardDiff duals or Flux tracked reals) flow through the solve.
#
# Returns `-Inf` when the solve finishes with a retcode other than `:Success` or
# `:Terminated`. Otherwise returns the log likelihood of the solution plus the sum of
# `logpdf(prior, θᵢ)` over the priors paired element-wise with the values of `θ`.
function (P::DynamicHMCPosterior)(θ)
    @unpack a = θ
    @unpack alg, problem, likelihood, priors, kwargs = P
    # match u0's element type to the parameter so AD numbers are not truncated
    prob = remake(problem, u0 = convert.(eltype(a), problem.u0), p = a)
    sol = solve(prob, alg; kwargs...)
    # `solve` on a single problem yields one solution, so check its own retcode.
    # (Iterating `sol` walks its contents, not solutions.)
    # NOTE(review): `-Inf` is a Float64 even when the success path returns an AD number
    # type; confirm the AD backend accepts this mixed return type.
    if sol.retcode ∉ (:Success, :Terminated)
        return -Inf
    end
    return likelihood(sol) + mapreduce(logpdf, +, priors, θ)
end
"""
    data_log_likelihood(solution, data, t, σ)

Log likelihood of `data` given `solution`, with independent `Normal(0, σ)` errors.

`solution` is evaluated at each observation time `t[i]` and compared with column `i`
of `data`, so `data` needs one column per element of `t` (counted by `enumerate`).
Returns the sum over all times of the summed elementwise `logpdf` of the residuals.
"""
function data_log_likelihood(solution, data, t, σ)
    # the error distribution is the same for every observation, so build it once
    noise = Normal(0.0, σ)
    # `tᵢ` (not `t`) avoids shadowing the vector of observation times
    return sum(enumerate(t)) do (i, tᵢ)
        residuals = solution(tᵢ) .- data[:, i]
        sum(logpdf.(noise, residuals))
    end
end
"""
    dynamic_hmc_posterior(alg, problem, data, t, priors; σ = 0.01, kwargs...)

Build a `DynamicHMCPosterior` whose likelihood scores a solution against `data`
observed at times `t`, via `data_log_likelihood` with noise scale `σ`.

`σ` defaults to `0.01`. All remaining keyword arguments are stored in the posterior
and passed on to `solve` when it is evaluated.
"""
function dynamic_hmc_posterior(alg, problem, data, t, priors; σ = 0.01, kwargs...)
    likelihood = solution -> data_log_likelihood(solution, data, t, σ)
    return DynamicHMCPosterior(alg, problem, likelihood, priors, kwargs)
end
####
#### ODE setup and data generation
####

Random.seed!(1)

# Lotka–Volterra system with one free parameter, `a`
f1 = @ode_def LotkaVolterraTest1 begin
    dx = a*x - x*y
    dy = -3*y + x*y
end a

a₀ = 1.5                                 # true parameter
p = [a₀]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)

σ = 0.01                                 # noise, fixed for now
t = collect(range(1, 10; length = 10))   # observation times

# Simulate, then perturb the state at each observation time with Gaussian noise.
# Converting to an Array gives one column per observation time.
sol = solve(prob1, Tsit5())
randomized = VectorOfArray([sol(tᵢ) .+ σ .* randn(2) for tᵢ in t])
data = convert(Array, randomized)

P = dynamic_hmc_posterior(Tsit5(), prob1, data, t, (Normal(a₀, 0.1),); maxiters = 10^5)
P((a = 1.5,))                            # make sure log posterior works
####
#### Problem setup in the LogDensityProblems API, and derivative diagnostics
####
using LogDensityProblems | |
using LogDensityProblems: logdensity, logdensity_and_gradient | |
using PGFPlotsX # plotting | |
import ForwardDiff, Flux # AD choices | |
using TransformVariables | |
# `a` is positive: map it from the unconstrained real line with `asℝ₊`
trans = as((a = asℝ₊,))
# log posterior `P` expressed in the unconstrained (transformed) coordinate
ℓ = TransformedLogDensity(trans, P)
# AD gradient of ℓ, using the Flux backend
∇ℓ = ADgradient(:Flux, ℓ)

####
#### diagnostics -- derivatives
####

using FiniteDifferences

# Compare the finite-difference derivative (`central_fdm(5, 1)`) with the AD derivative
# of ℓ on a grid of the transformed coordinate.
x_grid = range(-1, 2; length = 100)
logdensity_at(z) = logdensity(ℓ, [z])

x_numerical_D = [central_fdm(5, 1)(logdensity_at, x) for x in x_grid]
x_AD_D = [first(last(logdensity_and_gradient(∇ℓ, [x]))) for x in x_grid]
x_logdensity = [logdensity_at(x) for x in x_grid]

@pgf Axis({ xlabel = "x", ylabel = "transformed logdensity" },
          Plot({ no_marks }, Table(x_grid, x_logdensity)))

@pgf Axis({ legend_pos = "south west", xlabel = "transformed coordinate" },
          Plot({ no_marks, "red" }, Table(x_grid, x_numerical_D)),
          LegendEntry("numerical deriv"),
          Plot({ no_marks, "blue", dashed }, Table(x_grid, x_AD_D)),
          LegendEntry("AD deriv"))

@pgf Axis({ ylabel = "AD - numerical deriv", xlabel = "transformed coordinate" },
          Plot({ no_marks }, Table(x_grid, x_AD_D .- x_numerical_D)))

# relative error; spikes wherever the numerical derivative is close to zero
@pgf Axis({ ylabel = "AD / numerical deriv - 1", xlabel = "transformed coordinate"},
          Plot({ no_marks }, Table(x_grid, x_AD_D ./ x_numerical_D .- 1)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.