Skip to content

Instantly share code, notes, and snippets.

@tpapp
Created August 13, 2019 08:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tpapp/569bdef475c0a44c46d5c9e6d5e9e493 to your computer and use it in GitHub Desktop.
Save tpapp/569bdef475c0a44c46d5c9e6d5e9e493 to your computer and use it in GitHub Desktop.
MWE for DiffEq AD mismatch
#####
##### Self-contained example to dissect broken tests in
##### https://github.com/JuliaDiffEq/DiffEqBayes.jl/blob/e033307768892e2f7242ae0aab3e09ec4819c11b/test/dynamicHMC.jl
#####
##### NOTE: you need DynamicHMC#master and LogDensityProblems#0.9.x for this to work
####
#### part that just relies on the DiffEq ecosystem
####
using OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools, Parameters,
Distributions, Random
"""
    DynamicHMCPosterior(alg, problem, likelihood, priors, kwargs)

Callable container holding everything needed to evaluate a log posterior that is
based on solving a differential equation problem.

# Fields
- `alg`: solver algorithm handed to `solve`.
- `problem`: template problem; each evaluation `remake`s it with new `u0` and `p`.
- `likelihood`: callable applied to the solution; its value is added to the log prior.
- `priors`: collection of distributions, paired element-wise with the parameter
  values through `logpdf`.
- `kwargs`: keyword arguments splatted into `solve`.
"""
struct DynamicHMCPosterior{A, M, L, R, K}
    alg::A
    problem::M
    likelihood::L
    priors::R
    kwargs::K
end
# Log posterior at `θ`; `θ` must provide a property `a` (e.g. the named tuple `(a = 1.5,)`).
#
# The stored problem is rebuilt with `a` as its parameter `p` and with `u0` converted to
# `eltype(a)`, then solved with the stored algorithm and keyword arguments. Returns `-Inf`
# when the return-code check rejects the solution; otherwise returns the likelihood of the
# solution plus the sum of `logpdf(prior, value)` over `priors` paired with the elements
# of `θ`.
function (P::DynamicHMCPosterior)(θ)
    @unpack a = θ
    # Presumably the conversion lets AD number types flow through the solver state.
    # NOTE(review): confirm.
    initial_state = convert.(eltype(a), P.problem.u0)
    candidate = remake(P.problem, u0 = initial_state, p = a)
    sol = solve(candidate, P.alg; P.kwargs...)
    # Accept only if every element reports :Success, or every element reports :Terminated.
    # NOTE(review): this iterates `sol` and reads `retcode` from each element; for a single
    # solution object `sol.retcode` may be what is intended — confirm against the solver API.
    accepted = all(s.retcode == :Success for s in sol) ||
               all(s.retcode == :Terminated for s in sol)
    accepted || return -Inf
    log_likelihood = P.likelihood(sol)
    log_prior = mapreduce(logpdf, +, P.priors, θ)
    return log_likelihood + log_prior
end
"""
    data_log_likelihood(solution, data, t, σ)

Sum of `Normal(0.0, σ)` log densities of the residuals between `solution` and `data`.

For the `i`-th element `tᵢ` of `t`, the residual is `solution(tᵢ) .- data[:, i]`, i.e.
column `i` of `data` is compared with the solution evaluated at `tᵢ`.
"""
function data_log_likelihood(solution, data, t, σ)
    noise = Normal(0.0, σ)
    return sum(enumerate(t)) do (i, tᵢ)
        residual = solution(tᵢ) .- data[:, i]
        sum(logpdf.(noise, residual))
    end
end
"""
    dynamic_hmc_posterior(alg, problem, data, t, priors; σ = 0.01, kwargs...)

Build a `DynamicHMCPosterior` whose likelihood is `data_log_likelihood` evaluated on
`data` at the observation times `t`.

# Arguments
- `alg`: solver algorithm, stored and later passed to `solve`.
- `problem`: differential equation problem, stored as the template to be re-solved.
- `data`: observations; column `i` is compared with the solution at `t[i]`.
- `t`: observation times.
- `priors`: prior distributions, stored unchanged.

# Keywords
- `σ = 0.01`: standard deviation of the observation noise used by the likelihood.
  Previously this value was hard-coded; the default keeps the old behavior.
- `kwargs...`: all remaining keywords are stored and forwarded to `solve`.
"""
function dynamic_hmc_posterior(alg, problem, data, t, priors; σ = 0.01, kwargs...)
    likelihood = solution -> data_log_likelihood(solution, data, t, σ)
    return DynamicHMCPosterior(alg, problem, likelihood, priors, kwargs)
end
####
#### ODE setup and data generation
####
# Fix the global RNG so that the simulated noise below is reproducible.
Random.seed!(1)

# Two-state system with a single free parameter `a`.
f1 = @ode_def LotkaVolterraTest1 begin
    dx = a*x - x*y
    dy = -3*y + x*y
end a

a₀ = 1.5                                  # true parameter
p = [a₀]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)

σ = 0.01                                  # noise, fixed for now
t = collect(range(1, 10; length = 10))    # observation times

# Simulate at the true parameter and perturb each observation with Gaussian noise.
sol = solve(prob1, Tsit5())
randomized = VectorOfArray([sol(tᵢ) + σ * randn(2) for tᵢ in t])
data = convert(Array, randomized)

P = dynamic_hmc_posterior(Tsit5(), prob1, data, t, (Normal(a₀, 0.1),); maxiters = 10^5)
P((a = 1.5,))                             # make sure log posterior works
####
#### problem setup in the LogDensityProblems API and derivative diagnostics
####
using LogDensityProblems
using LogDensityProblems: logdensity, logdensity_and_gradient
using PGFPlotsX # plotting
import ForwardDiff, Flux # AD choices
using TransformVariables
# Transformation for a named tuple with the single field `a`, using `asℝ₊`
# (presumably a map onto the positive reals — see the TransformVariables docs).
trans = as((a = asℝ₊,))
# Log posterior `P` wrapped so that it is evaluated in the transformed coordinates.
ℓ = TransformedLogDensity(trans, P)
# The same log density with gradients supplied by the `:Flux` AD backend.
∇ℓ = ADgradient(:Flux, ℓ)
####
#### diagnostics -- derivatives
####
using FiniteDifferences
# Grid in the transformed (unconstrained) coordinate on which derivatives are compared.
x_grid = range(-1, 2; length = 100)

# Finite-difference derivative of the transformed log density at every grid point;
# a fresh `central_fdm(5, 1)` method is constructed for each point.
x_numerical_D = [central_fdm(5, 1)(z -> logdensity(ℓ, [z]), x) for x in x_grid]

# AD derivative: first component of the gradient returned by `logdensity_and_gradient`.
x_AD_D = map(x_grid) do x
    value_and_gradient = logdensity_and_gradient(∇ℓ, [x])
    first(last(value_and_gradient))
end

# Transformed log density itself.
@pgf Axis(
    { xlabel = "x", ylabel = "transformed logdensity" },
    Plot({ no_marks }, Table(x_grid, [logdensity(ℓ, [x]) for x in x_grid])),
)

# Both derivatives overlaid.
@pgf Axis(
    { legend_pos = "south west", xlabel = "transformed coordinate" },
    Plot({ no_marks, "red" }, Table(x_grid, x_numerical_D)),
    LegendEntry("numerical deriv"),
    Plot({ no_marks, "blue", dashed }, Table(x_grid, x_AD_D)),
    LegendEntry("AD deriv"),
)

# Absolute discrepancy between AD and the numerical derivative.
@pgf Axis(
    { ylabel = "AD - numerical deriv", xlabel = "transformed coordinate" },
    Plot({ no_marks }, Table(x_grid, x_AD_D .- x_numerical_D)),
)

# Relative discrepancy between AD and the numerical derivative.
@pgf Axis(
    { ylabel = "AD / numerical deriv - 1", xlabel = "transformed coordinate" },
    Plot({ no_marks }, Table(x_grid, x_AD_D ./ x_numerical_D .- 1)),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment