Skip to content

Instantly share code, notes, and snippets.

@MikeInnes
Created March 17, 2018 15:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MikeInnes/8f3f7f629742c8b8c03166e49373625b to your computer and use it in GitHub Desktop.
using OrdinaryDiffEq, Plots
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# ODE setup #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# The ODE
# In-place Lotka–Volterra vector field for OrdinaryDiffEq.
# u = (x, y) = (prey, predator) populations; p = (α, β, δ, γ) rate constants.
# Writes the derivatives into `du`; `t` is unused (autonomous system).
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β*y) * x   # prey: intrinsic growth minus predation
    du[2] = (δ*x - γ) * y   # predator: gain from predation minus mortality
end
const initial_pop = 1
# Solve the ODE with a given set of parameters, to see how the predator/prey
# populations behave over time.
# Integrate the Lotka–Volterra system over t ∈ [0, tfinal] for one
# (predator, prey) parameter pair, both species starting at `initial_pop`.
# `predator` and `prey` are 2-element collections; their concatenation gives
# p = (α, β, δ, γ). The eltype of the parameters is propagated into the
# state and time span so Dual/Tracked numbers differentiate through the solve.
function trajectory(predator, prey, tfinal = 10)
    params = [predator..., prey...]
    T = eltype(params)
    u0 = T.([initial_pop, initial_pop])
    problem = ODEProblem(lotka_volterra, u0, (T(0), T(tfinal)), params)
    # dtmin guards against the solver grinding to a halt on blow-ups
    return solve(problem, Tsit5(), dtmin = 1e-4)
end
# See an example solution: plot both populations over the default t ∈ [0, 10]
# for one hand-picked parameter set (y-axis clipped to [0, 6] for readability).
plot(trajectory((1.8, 1.5), (1.2, 3)), ylim=(0,6))
# For now, our loss is the deviation from the initial population;
# we are optimising for stability.
# Stability loss for a solved trajectory: mean squared deviation of every
# species from `initial_pop`, sampled at 50 evenly spaced times over the
# solved span. Lower = populations stay closer to the initial point.
function stability(sol::ODESolution)
    # NOTE(review): a failed solve scores a *perfect* zero loss here, which
    # rewards blow-ups — consider returning a large penalty instead.
    sol.retcode != :Success && return zero(sol.u[1][1])
    # Sample the full solved time span. The previous hard-coded
    # linspace(0, 10) ignored everything past t = 10, so the tfinal = 100
    # overload below was scoring only the first tenth of its trajectory.
    series = sol.(linspace(sol.t[1], sol.t[end]))
    sum(u -> sum(x -> (x - initial_pop)^2, u), series) / length(series)
end
# Convenience overload: solve the ODE for these parameters, then score the
# resulting trajectory. Uses a longer horizon (tfinal = 100) than the
# preview plots so long-run behaviour is penalised.
function stability(predator, prey, tfinal = 100)
    return stability(trajectory(predator, prey, tfinal))
end
# Preview the loss on the example parameters — sanity check before optimising
stability((1.8,1.5),(1.2,3))
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# Autodiff #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
using Flux
import Flux.Tracker: Call, TrackedVector, track, back
using ForwardDiff: Dual, value, partials
# Hook `stability` into Flux's AD
# We use forward-mode differentiation inside the model.
# Forward-over-reverse hook: when the predator parameters are a Flux
# TrackedVector, compute d(stability)/d(predator) with ForwardDiff dual
# numbers, then register the primal value and its partials on Tracker's
# tape so a later `back!` can propagate gradients.
function stability(predator::TrackedVector, prey)
# Seed the two parameters into separate partial slots: predator[1] → slot 1,
# predator[2] → slot 2, so `partials(ds)` holds the full gradient.
ds = stability(Tracker.data(predator) + [Dual(0,1,0),Dual(0,0,1)], prey)
# Record the primal (`value(ds)`) on the tape; stash `partials(ds)` in the
# Call so the custom `back` rule below can use it.
track(Call(stability, partials(ds), predator), value(ds))
end
back(::typeof(stability), Δ, ds, ps) = back(ps, Δ*ds)
# Now we can take gradients w.r.t. the parameters.
ps = param([2.2, 1.0])     # tracked predator parameters
l = stability(ps, (2, 3))  # forward pass; records partials on the tape
Flux.back!(l)              # reverse pass fills ps.grad
ps.grad
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# Optimising Parameters #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
predator = param([2.2, 1.0])        # tracked parameters to optimise
prey = [2, 3]                       # prey parameters held fixed
data = Iterators.repeated((), 100)  # 100 empty "batches" → 100 optimiser steps
opt = ADAM([predator], 0.1)
# Callback: re-plot the current trajectory each step so we can watch it
# stabilise. (The lambda body continues on the next line.)
cb = () ->
display(plot(trajectory(Flux.data(predator), prey), ylim=(0,6)))
cb()
Flux.train!(() -> stability(predator, prey),
data, opt, cb = cb)
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# One-shot parameter optimisation #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
using Distributions
# Generate prey parameters
# Draw one random prey parameter pair: truncated normals on [0, 10],
# centred at 1.5 and 1.0 respectively, both with spread 3.
randprey() = [rand(TruncatedNormal(μ, 3, 0, 10)) for μ in (1.5, 1.0)]
# Lazy stream of `n` independent prey parameter draws.
randprey(n) = (randprey() for _ in 1:n)
# Predict predator params from prey: small MLP mapping the 2 prey parameters
# to 2 predator parameters (final relu keeps the predicted rates non-negative).
paramnet = Chain(Dense(2,16,relu),Dense(16,16,relu),Dense(16,2,relu))
paramnet(randprey())  # smoke test: one forward pass on a random sample
# Loss for one prey sample: how stable is the system when the network
# chooses the predator parameters? (Tracked output hits the AD hook above.)
loss(prey) = stability(paramnet(prey), prey)
loss(randprey())
# Average loss over `n` fresh prey samples — rough generalisation metric.
testloss(n=1000) = sum(loss, randprey(n))/n
testloss()
opt = ADAM(Flux.params(paramnet))
# zip(...) wraps each prey sample in a 1-tuple, the shape train! expects.
Flux.train!(loss, zip(randprey(10_000)), opt)
testloss()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment