Skip to content

Instantly share code, notes, and snippets.

@drozzy
Created October 22, 2020 21:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save drozzy/a4fb9bab9303bb63784f158facb5a974 to your computer and use it in GitHub Desktop.
Save drozzy/a4fb9bab9303bb63784f158facb5a974 to your computer and use it in GitHub Desktop.
using DifferentialEquations, Flux, Optim, DiffEqFlux, DiffEqSensitivity, Plots
using Plots
"""
    reaction1!(du, u, p, t)

In-place right-hand side for the reaction A + 2B → C with rate
`k * [A] * [B]^2`.

# Arguments
- `du`: output vector; overwritten with `(d[A]/dt, d[B]/dt, d[C]/dt)`.
- `u`: current state `(A, B, C)`; only A and B enter the rate.
- `p`: parameter vector whose first element is the rate constant `k`.
- `t`: time (unused — the right-hand side does not depend on it).

Returns `nothing`; the derivative is written into `du`.
"""
function reaction1!(du, u, p, t)
    a, b, _ = u
    k = first(p)
    rate = k * a * b^2
    du[1] = -rate        # one A consumed per reaction event
    du[2] = -2 * rate    # two B consumed per reaction event
    du[3] = rate         # one C produced per reaction event
    return nothing
end
# Initial concentrations of A, B and C.
u0 = [2.0, 2.4, 1.0]

# Integration window, a coarse grid of "measurement" times, and a fine grid
# used only for drawing smooth curves.
t_start, t_end = 0.0, 10.0
t_span = (t_start, t_end)
t_samples = range(t_start, t_end; length=5)
t_plot = range(t_start, t_end; length=100)

# Ground-truth rate constant k for reaction1!.
p = [1.1]

# Solve the true system once. `data` stacks the state at each sample time as
# columns: one row per species, one column per entry of `t_samples`.
prob = ODEProblem(reaction1!, u0, t_span, p)
sol = solve(prob, Tsit5())
data = reduce(hcat, map(sol, t_samples))
"""
    make_plot()

Build a fresh plot of the reference solution `sol`. It draws dashed curves for
A, B and C on the fine grid `t_plot`, and gray markers at the sample times
`t_samples` (the values stored in `data`). Returns the plot object.
"""
function make_plot()
    truth = reduce(hcat, [sol(t) for t in t_plot])   # one row per species
    plt = plot(t_plot, [truth[i, :] for i in axes(truth, 1)];
               linestyle=:dash, label=["True A" "True B" "True C"],
               legend=:outertopright)
    # `data` already holds `sol` evaluated at `t_samples`; reuse it rather than
    # re-interpolating the same points.
    scatter!(plt, t_samples, [data[i, :] for i in axes(data, 1)];
             color=:gray, label=["Sampled A" "Sampled B" "Sampled C"])
    # Target `plt` explicitly so the limits can never land on a different
    # "current" plot.
    plot!(plt; xlims=(t_start - 1, t_end + 1), ylims=(0.0, 3))
    return plt
end
"""
    make_plot(prediction::AbstractArray)

Overlay model predictions sampled at `t_samples` on a fresh reference plot
from `make_plot()`. `prediction` has one row per species (A, B, C) and one
column per entry of `t_samples`, e.g. the `Array(neuralode(u0, p))` matrix
produced inside `loss`. Returns the plot object.

Throws `DimensionMismatch` if the number of columns differs from
`length(t_samples)`.
"""
function make_plot(prediction::AbstractArray)
    ncols = size(prediction, 2)
    ncols == length(t_samples) || throw(DimensionMismatch(
        "prediction has $ncols columns but t_samples has $(length(t_samples)) points"))
    series = [prediction[i, :] for i in 1:3]
    return scatter!(make_plot(), t_samples, series; label=["Pred A" "Pred B" "Pred C"])
end
"""
    make_plot(prediction::DiffEqBase.AbstractODESolution)

Overlay a solution object on a fresh reference plot from `make_plot()`. The
solution `prediction` is evaluated (interpolated) at every time in `t_plot`
and drawn as line series for A, B and C. Returns the plot object.
"""
function make_plot(prediction::DiffEqBase.AbstractODESolution)
    states = [prediction(t) for t in t_plot]
    # One vector per species: component k of the state at every plot time.
    series = [[s[k] for s in states] for k in 1:3]
    return plot!(make_plot(), t_plot, series; label=["Pred A" "Pred B" "Pred C"])
end
# Build the reference plot; its value is only shown automatically when this
# line runs in a REPL/notebook.
make_plot()

############ TRAIN ################
# Neural right-hand side: an identity pre-layer followed by a 3 → 50 → 3 MLP.
dRdt = let width = 50
    FastChain((x, _) -> x,
              FastDense(3, width, tanh),
              FastDense(width, 3))
end
# Wrap the network as a neural ODE whose solution is saved at the sample times.
neuralode = NeuralODE(dRdt, t_span; saveat=t_samples)
"""
    loss(p)

Sum-of-squares error between the neural ODE started at `u0` with parameters
`p` and the sampled reference `data`.

Returns `(l, prediction)`. The extra `prediction` matrix (states at
`t_samples`) is what `callback` receives for plotting.
"""
function loss(p)
    prediction = Array(neuralode(u0, p))
    l = sum(abs2, data .- prediction)
    return l, prediction
end
# Evaluate the loss once with the initial network parameters, e.g. to compile
# it and sanity-check shapes before training starts.
neuralode.p |> loss
# Training callback for `sciml_train`. It receives the current parameters `p`,
# the loss `l`, and the extra `prediction` returned by `loss`.
# It displays the loss and, unless `doplot = false`, a plot of the prediction
# against the reference data. Returning `false` keeps the optimizer running
# (`true` would halt it).
callback = function (p, l, prediction; doplot = true)
    display(l)
    if doplot
        # make_plot already returns a Plots object; display it directly instead
        # of wrapping it in a redundant plot(...) copy.
        display(make_plot(prediction))
    end
    return false
end
# Fit the network weights: ADAM with learning rate 0.01 for 250 iterations,
# starting from the initial parameters and reporting through `callback`.
result = DiffEqFlux.sciml_train(
    loss,
    neuralode.p,
    ADAM(0.01);
    cb = callback,
    maxiters = 250,
)
# Evaluate the *trained* model. `neuralode(u0)` alone would reuse the initial
# `neuralode.p`, so the optimized parameters from `result` are passed explicitly.
pred = neuralode(u0, result.minimizer)
# NOTE(review): neuralode was built with `saveat = t_samples`, so evaluating it on
# the fine `t_plot` grid interpolates between those few saved points — confirm
# the curve resolution is acceptable.
predictions = [pred(t) for t in t_plot]
# Split the states into one vector per species for plotting.
a, b, c = ([s[k] for s in predictions] for k in 1:3)
# Overlays onto the current plot (typically the last one built by the callback).
plot!(t_plot, [a, b, c], label=["Pred A2" "Pred B2" "Pred C2"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment