Created March 5, 2021 03:20.
Save jiweiqi/4a485b90c5f706eeacd1e7f150a9688d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
using OrdinaryDiffEq, Flux, Random, Plots | |
using DiffEqSensitivity | |
using ForwardDiff | |
using LinearAlgebra, Statistics | |
using ProgressBars, Printf | |
using Flux.Optimise: update!, ExpDecay | |
using Flux.Losses: mae, mse | |
using BSON: @save, @load | |
using MINPACK | |
Random.seed!(1234);   # reproducible data generation and weight initialization
###################################
# Arguments
is_restart = false;   # resume from ./checkpoint/mymodel.bson if true
n_epoch = 10000;      # first-stage (Adam) training epochs
n_plot = 50;          # print/plot/checkpoint every n_plot epochs
datasize = 10;        # saved time points per trajectory
tstep = 1;            # spacing between saved time points
n_exp_train = 30;     # number of training experiments
n_exp_test = 5;       # number of held-out (validation) experiments
n_exp = n_exp_train + n_exp_test;
noise = 0.05;         # relative (multiplicative) Gaussian noise on the data
ns = 6;               # number of chemical species
nr = 3;               # number of reactions in the CRNN
alg = AutoTsit5(Rosenbrock23(autodiff=true));  # stiffness-switching ODE solver
atol = 1e-6;
rtol = 1e-3;
# opt = ADAMW(5.f-3, (0.9, 0.999), 1.f-6);
# ADAMW with exponential LR decay (x0.5 every 500 full passes, floor 1e-4).
opt = Flux.Optimiser(ExpDecay(5e-3, 0.5, 500 * n_exp_train, 1e-4),
                     ADAMW(0.005, (0.9, 0.999), 1.f-6));
const lb = atol;      # floor used to clip concentrations before taking logs
####################################
# Ground-truth RHS of the 3-step transesterification mechanism.
# Species: TG(1), ROH(2), DG(3), MG(4), GL(5), R'CO2R(6); y[7] carries the
# experiment temperature as an extra constant state.
function trueODEfunc(dydt, y, k, t)
    r1 = k[1] * y[1] * y[2]    # TG + ROH -> DG + ester
    r2 = k[2] * y[3] * y[2]    # DG + ROH -> MG + ester
    r3 = k[3] * y[4] * y[2]    # MG + ROH -> GL + ester
    dydt[1] = -r1              # TG
    dydt[2] = -r1 - r2 - r3    # ROH (consumed in all three steps)
    dydt[3] = r1 - r2          # DG
    dydt[4] = r2 - r3          # MG
    dydt[5] = r3               # GL
    dydt[6] = r1 + r2 + r3     # R'CO2R
    dydt[7] = 0.0f0            # temperature is constant within an experiment
end
# Ground-truth kinetic parameters of the three reactions.
logA = Float32[18.60f0, 19.13f0, 7.93f0];
Ea = Float32[14.54f0, 14.42f0, 6.47f0]; # kcal/mol

# Arrhenius rate constants k = exp(logA) * exp(-Ea / (R * T)) for each reaction.
function Arrhenius(logA, Ea, T)
    R = 1.98720425864083f-3   # gas constant [kcal/(mol*K)]
    prefactor = exp.(logA)
    boltzmann = exp.(-Ea ./ R ./ T)
    return prefactor .* boltzmann
end
# Generate datasets
# Initial-condition matrix: one row per experiment, columns = [species..., T].
u0_list = rand(Float32, (n_exp, ns + 1));
u0_list[:, 1:2] = u0_list[:, 1:2] .* 2.0 .+ 0.2;   # TG, ROH drawn from [0.2, 2.2)
u0_list[:, 3:ns] .= 0.0;                            # intermediates/products start at zero
u0_list[:, ns + 1] = u0_list[:, ns + 1] .* 20.0 .+ 323.0; # T[K] in [323, 343)
tspan = Float32[0.0, datasize * tstep];
tsteps = range(tspan[1], tspan[2], length=datasize);
# Noisy ground-truth trajectories: (experiment, species, time).
ode_data_list = zeros(Float32, (n_exp, ns, datasize));
yscale_list = [];
# Per-species span of a trajectory (max - min over time), floored by lb so a
# flat species still yields a nonzero scale for loss normalization.
function max_min(ode_data)
    hi = maximum(ode_data, dims=2)
    lo = minimum(ode_data, dims=2)
    return hi .- lo .+ lb
end
# Integrate the true mechanism for every experiment, add multiplicative noise,
# and collect per-species scales for loss weighting.
for i in 1:n_exp
    u0 = u0_list[i, :]
    k = Arrhenius(logA, Ea, u0[end])   # rate constants at this experiment's T
    prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k)
    # Drop the last state (temperature) from the saved trajectory.
    ode_data = Array(solve(prob_trueode, alg, saveat=tsteps))[1:end - 1, :]
    ode_data .+= randn(size(ode_data)) .* ode_data .* noise   # relative noise
    ode_data_list[i, :, :] .= ode_data
    push!(yscale_list, max_min(ode_data))
end
# Worst-case (largest) scale per species across all experiments.
yscale = maximum(hcat(yscale_list...), dims=2);
# Flat trainable vector p: [w_b (nr) | w_out (nr*ns) | Ea weights (nr) | slope].
np = nr * (ns + 2) + 1;
p = randn(Float32, np) .* 0.1;
p[1:nr] .+= 0.8;                              # bias rate prefactors positive
p[nr * (ns + 1) + 1:nr * (ns + 2)] .+= 0.8;   # bias activation energies positive
p[end] = 0.1;                                 # slope, scaled by 10 inside p2vec
# Unpack the flat parameter vector into CRNN weights:
#   w_in  : (ns+1) x nr input weights (reaction orders stacked over Ea row)
#   w_b   : nr biases (log prefactors)
#   w_out : ns x nr stoichiometry-like output weights
function p2vec(p)
    slope = p[nr * (ns + 2) + 1] .* 10
    w_b = clamp.(p[1:nr] .* slope, -5.0, 10.0)
    w_out = reshape(p[nr + 1:nr * (ns + 1)], ns, nr)
    w_in_Ea = clamp.(abs.(p[nr * (ns + 1) + 1:nr * (ns + 2)] .* slope), 1.0, 40.0)
    # Reaction orders mirror negative stoichiometry (reactants only), in [0, 4].
    w_in = vcat(clamp.(-w_out, 0.0, 4.0), w_in_Ea')
    return w_in, w_b, w_out
end
# Pretty-print the current CRNN weights (rounded) for inspection during training.
function display_p(p)
    w_in, w_b, w_out = p2vec(p)
    println("species (column) reaction (row)")
    println("w_in | w_b")
    combined = vcat(w_in, w_b')'
    show(stdout, "text/plain", round.(combined, digits=3))
    println("\nw_out")
    show(stdout, "text/plain", round.(w_out', digits=3))
    println("\n")
end
display_p(p)
# Precomputed -1/R so that w_in_Ea * inv_R / T implements -Ea/(R*T).
const inv_R = - 1 / 1.98720425864083f-3;
# CRNN right-hand side: reaction rates are exponentials of a linear map of
# log-concentrations plus an Arrhenius temperature term. Reads the globals
# w_in, w_b, w_out, T that predict_neuralode sets before each solve.
function crnn!(du, u, p, t)
    logc = log.(clamp.(u, lb, Inf))   # clip at lb to keep logs finite
    activation = w_in' * vcat(logc, inv_R / T)
    rates = exp.(activation .+ w_b)
    du .= w_out * rates
end
u0 = u0_list[1, :];
# Kwargs stored in the problem are forwarded to `solve`; DiffEq's tolerance
# kwargs are `abstol`/`reltol` — the original `atol=`/`rtol=` names were not
# recognized, so the requested tolerances were silently never applied.
prob = ODEProblem(crnn!, u0[1:end-1], tspan, saveat=tsteps, abstol=atol, reltol=rtol)
# sense = BacksolveAdjoint(checkpointing=true; autojacvec=ZygoteVJP());
sense = ForwardDiffSensitivity()
# Solve the CRNN ODE for one experiment given the extended state u0 = [species; T].
# Sets the globals read by crnn! (weights and temperature) before solving.
# Fixes: the kwargs were misspelled as `sensalg`/`maxiter`; the correct DiffEq
# names are `sensealg`/`maxiters`, so the chosen sensitivity algorithm and the
# iteration cap were never actually in effect.
function predict_neuralode(u0, p)
    global w_in, w_b, w_out = p2vec(p)
    global T = u0[end]   # experiment temperature (last entry of extended state)
    pred = Array(solve(prob, alg, u0=@view(u0[1:end-1]),
                       p=p, sensealg=sense, maxiters=1000))
    return pred
end
predict_neuralode(u0, p)
# Scale-normalized mean-absolute-error between the stored noisy data and the
# CRNN prediction for experiment i_exp.
function loss_neuralode(p, i_exp)
    data = @view(ode_data_list[i_exp, :, :])
    pred = predict_neuralode(@view(u0_list[i_exp, :]), p)
    return mae(data ./ yscale, pred ./ yscale)
end
# Per-experiment plotting callback: one subplot per species comparing the noisy
# data (scatter) with the CRNN prediction (line); saved to figs/i_exp_<i>.png.
cbi = function (p, i_exp)
    obs = ode_data_list[i_exp, :, :]
    pred = predict_neuralode(u0_list[i_exp, :], p)
    subplots = []
    for s in 1:ns
        sub = scatter(tsteps, obs[s,:], markercolor=:transparent,
                      title=string(s), label=string("data_", s))
        plot!(sub, tsteps, pred[s,:], label=string("pred_", s))
        push!(subplots, sub)
    end
    fig = plot(subplots..., legend=false)
    png(fig, string("figs/i_exp_", i_exp))
    return false
end
# Loss histories, appended once per epoch by the callback below.
l_loss_train = []
l_loss_val = []
iter = 1
# Epoch callback: records losses; every n_plot epochs it prints the weights,
# re-plots one random experiment, refreshes the loss curve, and checkpoints.
cb = function (p, loss_train, loss_val)
    global l_loss_train, l_loss_val, iter
    push!(l_loss_train, loss_train)
    push!(l_loss_val, loss_val)
    if iter % n_plot == 0
        display_p(p)
        @printf("min loss train %.4e val %.4e\n", minimum(l_loss_train), minimum(l_loss_val))
        l_exp = randperm(n_exp)[1:1];   # one random experiment to re-plot
        println("update plot for ", l_exp)
        for i_exp in l_exp
            cbi(p, i_exp)
        end
        plt_loss = plot(l_loss_train, xscale=:log10, yscale=:log10,
            framestyle=:box, label="Training")
        plot!(plt_loss, l_loss_val, label="Validation")
        plot!(xlabel="Epoch", ylabel="Loss")
        png(plt_loss, "figs/loss")
        # Checkpoint everything needed to resume (see is_restart block below).
        @save "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val iter;
    end
    iter += 1;
end
# Resume parameters, optimizer state, and loss histories from the checkpoint.
if is_restart
    @load "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val iter;
    iter += 1;
end
# Plot the MINPACK solver trace: residual norm (fnorm) and step norm (xnorm)
# versus iteration, saved to figs/lm_loss.png.
function plot_trace(res)
    n_calls = res.trace.f_calls
    # Columns: iteration, fnorm, xnorm, step_time (step_time kept but unplotted).
    hist = ones(n_calls, 4)
    for i = 1:n_calls
        entry = res.trace.trace[i]
        hist[i, 1] = entry.iteration
        hist[i, 2] = entry.fnorm
        hist[i, 3] = entry.xnorm
        hist[i, 4] = entry.step_time
    end
    panels = []
    push!(panels, plot(
        hist[:, 1],
        hist[:, 2],
        xscale = :identity,
        yscale = :log10,
        label = "fnorm",
    ))
    # Offset xnorm by 1e-6 so zero-length steps survive the log scale.
    push!(panels, plot(
        hist[:, 1],
        hist[:, 3] .+ 1.e-6,
        xscale = :identity,
        yscale = :log10,
        label = "xnorm",
    ))
    fig = plot(panels...)
    png(fig, "figs/lm_loss")
end
# MINPACK residual function: one scaled-MAE loss per training experiment,
# written in place (avoids the temporary array of the comprehension form).
function f!(fvec, p)
    for i_exp in 1:n_exp_train
        fvec[i_exp] = loss_neuralode(p, i_exp)
    end
    return fvec
end
fvec = zeros(n_exp_train)
@time f!(fvec, p)
# MINPACK Jacobian: row i is the ForwardDiff gradient of experiment i's loss.
function g!(fjac, p)
    for row in 1:n_exp_train
        fjac[row, :] .= ForwardDiff.gradient(q -> loss_neuralode(q, row), p)
    end
    return fjac
end
fjac = zeros(n_exp_train, length(p))
@time g!(fjac, p)
i_exp = 1
epochs = ProgressBar(iter:n_epoch);
loss_epoch = zeros(Float32, n_exp);
grad_norm = zeros(Float32, n_exp_train);
# First-stage training: per-experiment stochastic Adam updates in a shuffled
# order, then a full train/validation loss sweep and the epoch callback.
for epoch in epochs
    global p
    for i_exp in randperm(n_exp_train)
        grad = ForwardDiff.gradient(x -> loss_neuralode(x, i_exp), p)
        grad_norm[i_exp] = norm(grad, 2)
        update!(opt, p, grad)
    end
    for i_exp in 1:n_exp
        loss_epoch[i_exp] = loss_neuralode(p, i_exp)
    end
    loss_train = mean(loss_epoch[1:n_exp_train]);
    loss_val = mean(loss_epoch[n_exp_train + 1:end]);
    set_description(epochs, string(@sprintf("Loss train %.2e val %.2e gnorm %.1e lr %.1e",
        loss_train, loss_val, mean(grad_norm), opt[1].eta)))
    cb(p, loss_train, loss_val);
end
# Second stage: Levenberg-Marquardt (MINPACK) refinement of the per-experiment
# residuals, starting from the Adam-trained parameters (as Float64 for MINPACK).
p0 = Float64.(p);
m = n_exp_train   # number of residual components
res = fsolve(f!, g!, p0, m,
    iterations = 2000, tol = 1e-8,
    show_trace = true, tracing = true; method = :lm)
plot_trace(res)
display_p(res.x)
# Plot fits for the first few experiments with the refined parameters.
for i_exp in 1:5
    cbi(res.x, i_exp)
end
# for i_exp in 1:n_exp | |
# cbi(p, i_exp) | |
# end |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.