# GitHub Gist feanor12/867f0b04a732824f3bff6846899bb167 by @feanor12
# (last active November 12, 2022 16:01).
# Custom Numerov integrator for OrdinaryDiffEq.jl, compared against a
# closed-form solution and plotted with UnicodePlots.
using UnicodePlots
using OrdinaryDiffEq
using Parameters
"""
    Numerov()

Algorithm tag that selects the Numerov integrator defined in this file when
passed to `solve`. It carries no fields; the per-run state is held in
`NumerovCache` (see `alg_cache`). Intended for fixed-step use
(`adaptive=false` with an explicit `dt`).
"""
struct Numerov <: OrdinaryDiffEq.OrdinaryDiffEqAlgorithm end
"""
    NumerovCache

Mutable per-integration state for `Numerov`. Numerov is a two-step recurrence,
so besides the current solution (stored in `integrator.u`) it keeps the value
one step back and the coefficient functions `g`, `s` sampled at the previous
and current time points.

The fields are deliberately left as `Any`. `alg_cache` builds the cache with
integer placeholders and `initialize!` later overwrites them with
floating-point values, so a concrete field type would have to accommodate both.
"""
mutable struct NumerovCache <: OrdinaryDiffEq.OrdinaryDiffEqMutableCache
    uprev::Any  # solution value one step before the current point
    g::Any      # g evaluated at the current time t
    gprev::Any  # g evaluated at the previous time t - dt
    s::Any      # s evaluated at the current time t
    sprev::Any  # s evaluated at the previous time t - dt
end
# Build the in-place (`Val{true}`) cache for `Numerov`.
# None of the solver-provided arguments are needed: every field starts as the
# integer placeholder 0, and `initialize!` replaces all of them with real
# values before the first step is taken.
function OrdinaryDiffEq.alg_cache(alg::Numerov, u, rate_prototype, uEltypeNoUnits,
                                  uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2,
                                  f, t, dt, reltol, p, calck, ::Type{Val{true}})
    placeholders = ntuple(_ -> 0, 5)  # uprev, g, gprev, s, sprev
    return NumerovCache(placeholders...)
end
# Convergence order that `Numerov` reports to OrdinaryDiffEq.
# NOTE(review): `initialize!` seeds the two-step recurrence with a second-order
# Taylor estimate of u(t - dt). That starting error may keep the observed
# global order below 4; confirm with a convergence test.
function OrdinaryDiffEq.alg_order(::Numerov)
    return 4
end
# `Numerov` is not FSAL ("first same as last"): it does not reuse the final
# derivative evaluation of one step as the first evaluation of the next.
function OrdinaryDiffEq.isfsal(::Numerov)
    return false
end
# Prepare `cache` before the first Numerov step.
#
# Numerov is a two-step recurrence, so it needs a solution value one step
# *before* the initial point. That value is estimated with a second-order
# Taylor expansion about t, using u'' = -g(t) u + s(t), the equation that the
# Numerov update in `perform_step!` discretizes:
#
#     u(t - dt) ≈ u(t) - dt u'(t) + dt^2/2 * (s(t) - g(t) u(t))
#
# Note the + sign on s(t): the second-order Taylor term is +dt^2/2 u''.
# The coefficients g and s are also sampled at t and t - dt, so the first step
# has all three grid points it needs.
#
# NOTE(review): for a `SecondOrderODEProblem` the state is presumably
# `ArrayPartition(du0, u0)`, so `u[1]` is the slot SciML calls the velocity
# and `u[2]` the position. This file consistently uses `u[1]` as the Numerov
# solution and `u[2]` as its derivative; verify before using nonzero initial
# conditions.
function OrdinaryDiffEq.initialize!(integrator, cache::NumerovCache)
    @unpack u, dt, t = integrator
    # presumably the user's `SONFOFunction`, which carries the callables g and s
    @unpack g, s = integrator.f.f1.f

    u0 = u[1]   # value treated as the solution at t
    du0 = u[2]  # value treated as its derivative at t
    g0 = g(t)
    s0 = s(t)

    cache.uprev = u0 - du0 * dt + dt^2 / 2 * (s0 - g0 * u0)

    cache.g = g0
    cache.gprev = g(t - dt)
    cache.s = s0
    cache.sprev = s(t - dt)
    return nothing
end
# Advance the integrator by one Numerov step, from t to t + dt.
#
# The update is the explicit Numerov recurrence for u'' = -g(t) u + s(t), with
# h = dt and indices n-1, n, n+1 meaning t - dt, t, t + dt:
#
#     (1 + h^2/12 g[n+1]) u[n+1] = 2 (1 - 5h^2/12 g[n]) u[n]
#                                  - (1 + h^2/12 g[n-1]) u[n-1]
#                                  + h^2/12 (s[n+1] + 10 s[n] + s[n-1])
#
# u[n-1], g[n-1], g[n], s[n-1] and s[n] come from `cache`; only g and s at
# t + dt are evaluated here. Afterwards the cached history is shifted forward
# by one step.
#
# NOTE(review): the recurrence assumes equally spaced points. The cached
# "previous" values were taken one step of the previous size back, so a step
# whose `dt` differs from the last one (e.g. a shortened final step) would be
# inaccurate. Use with `adaptive=false` and a fixed `dt`.
function OrdinaryDiffEq.perform_step!(integrator, cache::NumerovCache)
    @unpack t, dt, f = integrator
    @unpack g, s, gprev, sprev, uprev = cache
    gfun = f.f1.f.g
    sfun = f.f1.f.s

    u = integrator.u[1]
    tnext = dt + t
    gnext = gfun(tnext)
    snext = sfun(tnext)

    unext = (2 * u * (1 - (5 * dt^2 * g) / 12) -
             uprev * (1 + (dt^2 * gprev) / 12) +
             dt^2 / 12 * (snext + 10 * s + sprev)) /
            (1 + dt^2 / 12 * gnext)

    # shift the two-step history forward by one grid point
    cache.uprev = u
    integrator.u[1] = unext
    cache.gprev = g
    cache.g = gnext
    cache.sprev = s
    cache.s = snext
    return nothing
end
"""
    SONFOFunction(g, s)

Linear second-order ODE without a first-derivative term,

    u''(t) = -g(t) * u(t) + s(t),

where `g` and `s` are callables of `t`. This is the sign convention that the
Numerov recurrence in this file discretizes and that the closed-form reference
solution in `main` satisfies.

The type parameters keep the field types concrete, so calls to `g` and `s` can
be specialized. Construction is unchanged: `SONFOFunction(g, s)`.
"""
struct SONFOFunction{G,S}
    g::G
    s::S
end

# In-place acceleration in the `(dv, v, u, p, t)` form expected by
# `SecondOrderODEProblem`: dv .= -g(t) .* u .+ s(t), element-wise over `u`.
# The custom `Numerov` algorithm reads `g` and `s` directly and never calls
# this, but it makes the problem usable with other second-order solvers.
function (f::SONFOFunction)(dv, v, u, p, t)
    gt = f.g(t)
    st = f.s(t)
    dv .= st .- gt .* u
    return nothing
end
"""
    main(g, s, x1=100)

Solve u'' = -g*u + s with constant `g` and `s` and zero initial conditions on
t ∈ [0, 10] using `Numerov` with a fixed step of 1/10. Return a UnicodePlots
line plot of the numerical solution overlaid with the closed-form solution
y(t) = (s/g) * (1 - cos(sqrt(g) * t)).

`x1` is currently unused; it is kept so existing three-argument calls still
work.
"""
function main(g, s, x1=100)
    rhs = SONFOFunction(x -> g, x -> s)
    du0 = [0.0]
    u0 = [0.0]
    prob = SecondOrderODEProblem(rhs, du0, u0, (0.0, 10))
    sol = solve(prob, Numerov(); dt=1//10, dense=false, adaptive=false)

    # closed-form reference for constant g, s and zero initial conditions
    exact = @. (cos(sol.t * sqrt(g)) - 1) * (-s / g)
    numeric = [state[1] for state in sol.u]

    plt = lineplot(sol.t, numeric)
    return lineplot!(plt, sol.t, exact)
end
# Demo run with random constants g, s drawn from [0, 1). `display` ensures the
# plot is shown when the file is run as a script (`julia file.jl`), where
# top-level return values are not echoed as they are in the REPL.
display(main(rand(2)...))
# (Gist page footer: commenting on this gist requires a GitHub account.)