Skip to content

Instantly share code, notes, and snippets.

@ranjanan
Last active July 2, 2019 04:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ranjanan/6f4d1ec5288dbabee04c49201c5f6abb to your computer and use it in GitHub Desktop.
Save ranjanan/6f4d1ec5288dbabee04c49201c5f6abb to your computer and use it in GitHub Desktop.
Koopman
# Prerequisite: install MonteCarloIntegration.jl first (it provides `vegas`), e.g. `] add https://github.com/ranjanan/MonteCarloIntegration.jl`
using Cubature
# using Cuba
using MonteCarloIntegration
# include("vegas.jl")
"""
    koopman(g, prob, u0, p, args...; kwargs...)

Re-solve `prob` with initial condition `u0` and parameters `p`, then return the
observable `g` evaluated on the resulting solution.

`args...` and `kwargs...` are passed straight through to `solve` (e.g. the
algorithm and `saveat`).
"""
function koopman(g, prob, u0, p, args...; kwargs...)
    perturbed = remake(prob; u0 = u0, p = p)
    sol = solve(perturbed, args...; kwargs...)
    return g(sol)
end
"""
    koopman_cost(u0s, ps, g, prob, args...; maxevals=0, ireltol=1e-2, iabstol=1e-2,
                 use_vegas=false, kwargs...)

Estimate the expectation of the observable `g` over solutions of `prob`. The
initial-condition components are drawn independently from the distributions in
`u0s`, and the parameters from the distributions in `ps`. The estimate integrates
`koopman(g, prob, u0, p) * joint_density(u0, p)` over the box spanned by the
`minimum`/`maximum` of every distribution.

# Arguments
- `u0s`: one distribution per initial-condition component.
- `ps`: one distribution per parameter.
- `g`: observable applied to each solution.
- `prob`: problem that is `remake`d for every integration point.
- `args...`, `kwargs...`: forwarded to `solve` (e.g. the algorithm, `saveat`).

# Keywords
- `maxevals`: evaluation cap forwarded to `hcubature` (ignored when `use_vegas`).
- `ireltol`: relative tolerance for the chosen integrator.
- `iabstol`: absolute tolerance forwarded to `hcubature` (ignored when `use_vegas`).
- `use_vegas`: integrate with `vegas` (MonteCarloIntegration) instead of
  `hcubature` (Cubature).

# Returns
The integrator's result as-is. For `hcubature` this is `(integral, error_estimate)`.
The driver script destructures both paths as `(value, _)`.
"""
function koopman_cost(u0s, ps, g, prob, args...; maxevals=0,
                      ireltol = 1e-2, iabstol=1e-2, use_vegas=false, kwargs...)
    n = length(u0s)

    # Product of independent marginal densities. `init` makes an empty list (e.g. a
    # problem with no uncertain parameters) contribute a neutral factor instead of
    # throwing "reducing over an empty collection".
    joint_pdf(dists, vals) =
        mapreduce(((d, v),) -> pdf(d, v), *, zip(dists, vals); init = one(eltype(vals)))

    # Integrand over the stacked point x = [u0; p]. Slicing copies, so the new
    # problem never aliases the integrator's evaluation point.
    function _f(x)
        u0 = x[1:n]
        p = x[n+1:end]
        k = koopman(g, prob, u0, p, args...; kwargs...)
        w = joint_pdf(u0s, u0) * joint_pdf(ps, p)
        return k * w
    end

    # Integration box: the support bounds of every distribution.
    xs = [u0s; ps]
    st = minimum.(xs)
    en = maximum.(xs)
    @debug "koopman_cost integration box" st en

    if use_vegas
        # vegas is given the box bounds directly, so no unit-cube rescaling is needed.
        # NOTE(review): only the relative tolerance is forwarded here; `iabstol` and
        # `maxevals` do not reach vegas.
        return vegas(_f, st, en; rtol = ireltol)
    else
        return hcubature(_f, st, en;
                         reltol = ireltol, abstol = iabstol, maxevals = maxevals)
    end
end
"""
    montecarlo_cost(u0s, ps, g, prob, args...; num_monte, kwargs...)

Plain Monte Carlo estimate of the expectation of `g` over solutions of `prob`.

Each of the `num_monte` trajectories uses an initial condition sampled with
`rand.(u0s)` and parameters sampled with `rand.(ps)`. The function returns the mean
of `g` over all trajectories. `args...` and `kwargs...` are forwarded to `solve`.
"""
function montecarlo_cost(u0s, ps, g, prob, args...; num_monte, kwargs...)
    # Build each trajectory's problem from freshly sampled u0 (drawn first), then p.
    sample_problem(base, i, rep) = remake(base, u0 = rand.(u0s), p = rand.(ps))
    # Keep only the observable. The Bool is the ensemble's rerun flag in DiffEq's
    # output_func contract; `false` means "do not rerun".
    observe(sol, i) = (g(sol), false)

    ensemble = MonteCarloProblem(prob; output_func = observe, prob_func = sample_problem)
    results = solve(ensemble, args...; num_monte = num_monte, kwargs...)
    return mean(results.u)
end
using OrdinaryDiffEq, DiffEqMonteCarlo, Distributions, Test
include("koopman.jl")
"""
    f(du, u, p, t)

In-place right-hand side of a two-state Lotka-Volterra system whose only free
coefficient is `p[1]`, the linear growth rate of `u[1]`. The other coefficients
are fixed: 3 for the decay of `u[2]`, and 1 for both interaction terms.

The derivative is written into `du`. `t` is unused because the system is
autonomous. Returns `nothing`.
"""
function f(du, u, p, t)
    du[1] = p[1] * u[1] - u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
    return nothing
end
# --- Experiment 1: Lotka-Volterra (f) with one uncertain coefficient ---
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5]
prob = ODEProblem(f, u0, tspan, p)
# Nominal solve; `sol` is not referenced again in this script (possibly a
# compile warm-up -- confirm).
sol = solve(remake(prob, u0 = u0), Tsit5())
# Observable: amount by which the first state component exceeds 12, summed over
# the saved time points.
cost(sol) = sum(max(x[1] - 12, 0) for x in sol.u)
# Uncertainty model: both initial states and p[1] are uniformly distributed.
u0s = [Uniform(0.25, 5.5), Uniform(0.25, 5.5)]
ps = [Uniform(0.5, 2.0)]
# Koopman expectation (VEGAS) vs. plain Monte Carlo; the @time figures may
# include first-call compilation.
@time c1, _ = koopman_cost(u0s, ps, cost, prob, Tsit5();
                           saveat = 0.1, use_vegas = true)
@time c2 = montecarlo_cost(u0s, ps, cost, prob, Tsit5();
                           num_monte = 100_000, saveat = 0.1)
@show c1, c2
"""
    f2(du, u, p, t)

In-place right-hand side of a two-state Lotka-Volterra system in which all four
coefficients come from `p`:
- `p[1]` is the growth rate of `u[1]`;
- `p[2]` is the interaction loss of `u[1]`;
- `p[3]` is the decay rate of `u[2]`;
- `p[4]` is the interaction gain of `u[2]`.

The derivative is written into `du`. `t` is unused because the system is
autonomous. Returns `nothing`.
"""
function f2(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
# --- Experiment 2: Lotka-Volterra (f2) with all four coefficients uncertain ---
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(f2, u0, tspan, p)
# Redefines `cost` with a lower threshold: the amount by which the first state
# component exceeds 6, summed over the saved time points. This overwrites the
# threshold-12 method used for experiment 1.
cost(sol) = sum(max(x[1] - 6, 0) for x in sol.u)
# Uncertainty model: uniform initial states and uniform ranges around each nominal
# coefficient.
u0s = [Uniform(0.25, 5.5), Uniform(0.25, 5.5)]
ps = [Uniform(0.5, 2.0), Uniform(0.5, 1.5), Uniform(2.5, 3.5), Uniform(0.5, 1.5)]
# Koopman expectation (VEGAS) vs. plain Monte Carlo.
@time c1, _ = koopman_cost(u0s, ps, cost, prob, Tsit5();
                           saveat = 0.1, use_vegas = true)
@time c2 = montecarlo_cost(u0s, ps, cost, prob, Tsit5();
                           num_monte = 100_000, saveat = 0.1)
@show c1, c2
julia> include("koopmantest.jl")
st = [0.25, 0.25, 0.5]
en = [5.5, 5.5, 2.0]
abs(sd / Itot) = 0.006173570557686897
abs(sd / Itot) = 0.004568674209708778
abs(sd / Itot) = 0.0037499627643836657
abs(sd / Itot) = 0.0032123856396025796
abs(sd / Itot) = 0.002852758061603755
abs(sd / Itot) = 0.002586113950575031
abs(sd / Itot) = 0.002415043588973447
abs(sd / Itot) = 0.0022652010789296333
abs(sd / Itot) = 0.002137917763744343
abs(sd / Itot) = 0.0020213454962229537
abs(sd / Itot) = 0.0019128086555385388
abs(sd / Itot) = 0.0018464460888589287
abs(sd / Itot) = 0.0017735299224410674
abs(sd / Itot) = 0.0016984562630878756
abs(sd / Itot) = 0.0016389966219611401
abs(sd / Itot) = 0.0015850539304377834
abs(sd / Itot) = 0.0015391039085527094
abs(sd / Itot) = 0.0014986336554413061
abs(sd / Itot) = 0.0014541671657582061
abs(sd / Itot) = 0.001415779847375256
abs(sd / Itot) = 0.0013840192895460237
abs(sd / Itot) = 0.00135155496477406
abs(sd / Itot) = 0.0013235528562994454
abs(sd / Itot) = 0.0012907821104183865
abs(sd / Itot) = 0.001266662939531453
abs(sd / Itot) = 0.0012405483979915698
abs(sd / Itot) = 0.0012192766216697014
abs(sd / Itot) = 0.001199623977090778
abs(sd / Itot) = 0.001179118987334562
abs(sd / Itot) = 0.0011591704487498753
abs(sd / Itot) = 0.0011423859073484346
abs(sd / Itot) = 0.0011263584112780342
abs(sd / Itot) = 0.0011098896932083943
abs(sd / Itot) = 0.0010943620659613173
abs(sd / Itot) = 0.0010781267033999117
abs(sd / Itot) = 0.00106303452613052
abs(sd / Itot) = 0.001048623925630123
abs(sd / Itot) = 0.0010337859832191007
abs(sd / Itot) = 0.0010187342473147664
abs(sd / Itot) = 0.001005407203540136
abs(sd / Itot) = 0.0009923964731765596
abs(sd / Itot) = 0.0009783733262473534
abs(sd / Itot) = 0.0009669880454430367
abs(sd / Itot) = 0.0009561931638116755
abs(sd / Itot) = 0.0009462842480821386
abs(sd / Itot) = 0.0009344479048989255
abs(sd / Itot) = 0.0009249089677737511
abs(sd / Itot) = 0.0009150597029660239
abs(sd / Itot) = 0.0009044090003748447
abs(sd / Itot) = 0.0008947423960952532
abs(sd / Itot) = 0.0008865863025954616
abs(sd / Itot) = 0.0008774272917221622
abs(sd / Itot) = 0.0008694275209260272
abs(sd / Itot) = 0.0008624958775622968
abs(sd / Itot) = 0.0008552162436262015
abs(sd / Itot) = 0.0008476483237161635
abs(sd / Itot) = 0.0008398192610376911
abs(sd / Itot) = 0.0008326171057483726
abs(sd / Itot) = 0.0008253738608538049
abs(sd / Itot) = 0.0008178870077533955
abs(sd / Itot) = 0.0008115892782157336
abs(sd / Itot) = 0.0008050515257036761
abs(sd / Itot) = 0.0007984944830077183
abs(sd / Itot) = 0.0007920406080545369
abs(sd / Itot) = 0.0007850700432384622
abs(sd / Itot) = 0.0007794180560597966
abs(sd / Itot) = 0.0007736774497712322
abs(sd / Itot) = 0.0007681874429611999
abs(sd / Itot) = 0.0007624142702241227
abs(sd / Itot) = 0.0007579521167451838
abs(sd / Itot) = 0.0007518112187894674
abs(sd / Itot) = 0.0007457554455369249
abs(sd / Itot) = 0.0007405627056401335
abs(sd / Itot) = 0.0007357119940808038
abs(sd / Itot) = 0.0007306666917066382
abs(sd / Itot) = 0.0007255360727218986
abs(sd / Itot) = 0.0007210553991851174
abs(sd / Itot) = 0.0007165407525876763
abs(sd / Itot) = 0.0007122668681230613
abs(sd / Itot) = 0.0007073305942862755
abs(sd / Itot) = 0.0007031380952039414
abs(sd / Itot) = 0.0006989581173033717
abs(sd / Itot) = 0.0006952870540037916
abs(sd / Itot) = 0.0006908649570204112
abs(sd / Itot) = 0.000686696820667498
abs(sd / Itot) = 0.0006827563224878402
abs(sd / Itot) = 0.000678927989743132
abs(sd / Itot) = 0.0006752868829396415
abs(sd / Itot) = 0.0006720943495040623
abs(sd / Itot) = 0.0006686525623632535
abs(sd / Itot) = 0.0006652727788400343
abs(sd / Itot) = 0.0006622433804475721
abs(sd / Itot) = 0.0006587491207566949
abs(sd / Itot) = 0.0006546807636582385
abs(sd / Itot) = 0.0006513808274262075
abs(sd / Itot) = 0.0006480768974690109
abs(sd / Itot) = 0.0006447556865493701
abs(sd / Itot) = 0.0006416990961095163
abs(sd / Itot) = 0.0006384612550085456
abs(sd / Itot) = 0.0006353777418388643
nevals = 10000000
619.765409 seconds (2.09 G allocations: 199.873 GiB, 5.75% gc time)
13.237179 seconds (23.13 M allocations: 2.099 GiB, 8.46% gc time)
(c1, c2) = (0.046188552382282484, 0.04996642050491786)
(0.046188552382282484, 0.04996642050491786)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment