-
-
Save MikeInnes/8f3f7f629742c8b8c03166e49373625b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using OrdinaryDiffEq, Plots | |
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# ODE setup                                             #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
"""
    lotka_volterra(du, u, p, t)

The ODE: write the time derivatives of the Lotka–Volterra system into `du`
(in-place form `f(du, u, p, t)`, as passed to `ODEProblem`).

- `u = (x, y)`: current values of the two populations.
- `p = (α, β, δ, γ)`: model parameters, giving
  `dx/dt = (α - β*y)*x` and `dy/dt = (δ*x - γ)*y`.
- `t`: current time; not used, since the right-hand side does not depend on it.

Returns `nothing`; the result is communicated only through `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β * y) * x
    du[2] = (δ * x - γ) * y
    return nothing
end
# Starting size of both populations; also the reference level the stability
# loss measures deviations from.
const initial_pop = 1

"""
    trajectory(predator, prey, tfinal = 10)

Solve the ODE with a given set of parameters, to see how the predator/prey
populations behave over time.

The parameter vector is `[predator..., prey...]`, which `lotka_volterra`
unpacks as `(α, β, δ, γ)`. Both populations start at `initial_pop`, and the
system is integrated over `[0, tfinal]` with `Tsit5()` (minimum step `1e-4`).
Returns the solution object produced by `solve`.

The problem's number type follows the parameters' element type, so dual-number
parameters carry derivatives through the solve. Integer parameters are promoted
to floating point first.
"""
function trajectory(predator, prey, tfinal = 10)
    params = [predator..., prey...]
    P = eltype(params)
    # With all-integer parameters (e.g. `trajectory((2, 1), (2, 3))`) the state
    # and time span would be integer-typed, and fractional derivatives/steps
    # could not be stored in them (InexactError). Promote only that case.
    T = P <: Integer ? float(P) : P
    u0 = T.([initial_pop, initial_pop])
    tspan = (T(0), T(tfinal))
    prob = ODEProblem(lotka_volterra, u0, tspan, params)
    return solve(prob, Tsit5(), dtmin = 1e-4)
end
# See an example solution: plot both populations, with the y-axis clipped to [0, 6].
let example = trajectory((1.8, 1.5), (1.2, 3))
    plot(example; ylim = (0, 6))
end
"""
    stability(sol::ODESolution)

For now, our loss is the deviation from the initial population; we are
optimising for stability.

The solution is interpolated at the points of `linspace(0, 10)`. For each
sample, the squared deviations of the state components from `initial_pop` are
summed. The result is the mean of these sums over all samples.

If the solve did not finish successfully (`sol.retcode != :Success`), a zero
of the solution's element type is returned instead.

NOTE(review): scoring a failed solve as zero loss makes it look perfectly
stable. Also, the sampling window is fixed at [0, 10] whatever span `sol` was
solved over. Confirm both are intended.
"""
function stability(sol::ODESolution)
    if sol.retcode != :Success
        return zero(sol.u[1][1])
    end
    samples = sol.(linspace(0, 10))
    sqdev(state) = sum(v -> (v - initial_pop)^2, state)
    return sum(sqdev, samples) / length(samples)
end
"""
    stability(predator, prey, tfinal = 100)

Solve the system for the given parameters over `[0, tfinal]` with `trajectory`
and score the resulting solution with `stability(sol)`.

NOTE(review): the default horizon (100) is longer than the [0, 10] window that
`stability(sol)` samples. Time beyond t = 10 is integrated but only affects the
result through the solver's return code. Confirm this is intended.
"""
function stability(predator, prey, tfinal = 100)
    sol = trajectory(predator, prey, tfinal)
    return stability(sol)
end

# Preview the loss
let predator = (1.8, 1.5), prey = (1.2, 3)
    stability(predator, prey)
end
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# Autodiff                                              #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
using Flux | |
import Flux.Tracker: Call, TrackedVector, track, back | |
using ForwardDiff: Dual, value, partials | |
"""
    stability(predator::TrackedVector, prey, tfinal = 100)

Hook `stability` into Flux's AD, using forward-mode differentiation inside
the model.

Each of the two predator parameters is seeded with its own unit partial as a
ForwardDiff `Dual`. The loss is then evaluated once on the duals. Its value
becomes the tracked output. Its partials (the gradient with respect to
`predator`) are stored in the recorded `Call`, for the matching
`back(::typeof(stability), ...)` rule to consume.
"""
function stability(predator::TrackedVector, prey, tfinal = 100)
    # Unit seeds for exactly two predator parameters; a length mismatch raises
    # a DimensionMismatch in the element-wise addition below.
    seeds = [Dual(0, 1, 0), Dual(0, 0, 1)]
    # Accept and forward `tfinal`, so that an explicit horizon still goes through
    # this hook. Otherwise it would dispatch to the untracked generic method.
    ds = stability(Tracker.data(predator) + seeds, prey, tfinal)
    return track(Call(stability, partials(ds), predator), value(ds))
end
# Reverse rule for the `Call` recorded by the tracked `stability` method.
# `ds` holds the forward-mode partials stored at that call, and `ps` holds the
# tracked predator parameters. Scale the partials by the incoming sensitivity
# `Δ` and propagate the result into `ps`.
function back(::typeof(stability), Δ, ds, ps)
    scaled = Δ * ds
    return back(ps, scaled)
end
# Now we can take gradients w.r.t. the parameters: evaluate the tracked loss,
# backpropagate it, then inspect the gradient accumulated on `ps`.
ps = param([2.2, 1.0])
l = stability(ps, (2, 3))
l |> Flux.back!
ps.grad
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# Optimising Parameters                                 #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# Fit the (tracked) predator parameters for fixed prey parameters.
predator = param([2.2, 1.0])
prey = [2, 3]
# 100 empty argument tuples: `train!` calls the zero-argument loss once per element.
data = Iterators.repeated((), 100)
opt = ADAM([predator], 0.1)
# Callback: redraw the trajectory for the current, untracked parameter values.
cb = function ()
    current = Flux.data(predator)
    display(plot(trajectory(current, prey), ylim = (0, 6)))
end
cb()
Flux.train!(() -> stability(predator, prey), data, opt; cb = cb)
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
# One-shot parameter optimisation                       #
# ––––––––––––––––––––––––––––––––––––––––––––––––––––– #
using Distributions | |
# Generate prey parameters. Each one is drawn from a truncated normal with
# mean 1.5 (first) or 1.0 (second), σ = 3, and bounds [0, 10].
randprey() = [rand(TruncatedNormal(μ, 3, 0, 10)) for μ in (1.5, 1.0)]
# A lazy generator of `n` independently drawn prey parameter vectors.
randprey(n) = (randprey() for _ in 1:n)
# Predict predator params from prey: an MLP mapping the 2 prey parameters to
# the 2 predator parameters. The final `relu` keeps the outputs non-negative.
paramnet = Chain(Dense(2, 16, relu), Dense(16, 16, relu), Dense(16, 2, relu))
paramnet(randprey())

"""
    loss(prey)

Stability loss for the predator parameters that `paramnet` predicts from `prey`.
"""
function loss(prey)
    predicted = paramnet(prey)
    return stability(predicted, prey)
end
loss(randprey())

"""
    testloss(n = 1000)

Mean `loss` over `n` freshly sampled prey parameter vectors.

NOTE(review): `loss` presumably returns a tracked value here, so this also
records the AD graph for all `n` samples. Consider `Flux.data` if only the
number is needed.
"""
function testloss(n = 1000)
    total = sum(loss, randprey(n))
    return total / n
end
testloss()
# Train `paramnet` on 10_000 sampled prey vectors, then re-check the test loss.
# `zip` wraps each sample in a 1-tuple, which `train!` splats into `loss`.
opt = ADAM(Flux.params(paramnet))
let samples = zip(randprey(10_000))
    Flux.train!(loss, samples, opt)
end
testloss()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment