Skip to content

Instantly share code, notes, and snippets.

@msakai
Created March 8, 2021 15:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msakai/a7476722ddf1360a5f06af4d7d3ba75f to your computer and use it in GitHub Desktop.
Save msakai/a7476722ddf1360a5f06af4d7d3ba75f to your computer and use it in GitHub Desktop.
{-# LANGUAGE FlexibleContexts #-}
import Control.Exception
import qualified Data.Vector.Generic as VG
import Numeric.LinearAlgebra
-- | One step of the explicit (forward) Euler method for @y' = f t y@.
--
-- Given the step size @h@, the current time @t@ and the current value @y@,
-- returns @y + h * f t y@, the estimate of the solution at @t + h@.
eulerMethod :: Fractional a => (a -> a -> a) -> a -> a -> a -> a
eulerMethod f h t y =
  let slope = f t y
   in y + h * slope
-- https://ja.wikipedia.org/wiki/%E3%83%AB%E3%83%B3%E3%82%B2%EF%BC%9D%E3%82%AF%E3%83%83%E3%82%BF%E6%B3%95
-- Classical Runge–Kutta method (RK4)

-- | One step of the classical fourth-order Runge–Kutta method for
-- @y' = f t y@ with step size @h@, starting from value @y@ at time @t@.
--
-- Four slopes are sampled (at @t@, twice at the midpoint, and at @t + h@),
-- and the result is @y + (h / 6) * (k1 + 2 k2 + 2 k3 + k4)@.
rk4 :: Fractional a => (a -> a -> a) -> a -> a -> a -> a
rk4 f h t y = y + (h / 6) * sum [k1, 2 * k2, 2 * k3, k4]
  where
    half = h / 2
    tMid = t + half
    k1 = f t y
    k2 = f tMid (y + half * k1)
    k3 = f tMid (y + half * k2)
    k4 = f (t + h) (y + h * k3)
-- Classical Runge–Kutta method (RK4), vector-valued version

-- | One classical RK4 step for a system @y' = f t y@ whose state is an
-- hmatrix 'Vector' of 'Double's.
--
-- Scalar factors are wrapped with 'scalar' before being combined with the
-- state vectors. NOTE(review): this relies on hmatrix expanding one-element
-- vectors in arithmetic (also for the integer literals and the 'sum' start
-- value) — confirm against the hmatrix version in use.
rk4Vec :: (Double -> Vector Double -> Vector Double) -> Double -> Double -> Vector Double -> Vector Double
rk4Vec f h t y = y + scalar (h / 6) * sum [k1, 2 * k2, 2 * k3, k4]
  where
    half = h / 2
    tMid = t + half
    k1 = f t y
    k2 = f tMid (y + scalar half * k1)
    k3 = f tMid (y + scalar half * k2)
    k4 = f (t + h) (y + scalar h * k3)
-- | A Butcher tableau @(c, a, b)@ describing a Runge–Kutta method:
--
-- * @c@ — stage nodes; stage @i@ is evaluated at time @t + h * c_i@,
-- * @a@ — stage coefficient matrix; row @i@ weights earlier slopes when
--   forming the input value for stage @i@,
-- * @b@ — weights used to combine the stage slopes into the final step.
type ButcherTableau = (Vector Double, Matrix Double, Vector Double)
-- | Returns its second argument unchanged, after asserting (with
-- 'Control.Exception.assert') that the tableau's dimensions fit together:
-- @a@ is square, @c@ has one node per row of @a@, and @b@ has one weight per
-- column of @a@.
--
-- Like every 'assert', the checks are skipped when assertions are disabled
-- (e.g. when compiling with @-O@ or @-fignore-asserts@).
checkButcherTableau :: ButcherTableau -> a -> a
checkButcherTableau (c, a, b) x =
  assert isSquare (assert nodesFit (assert weightsFit x))
  where
    isSquare = rows a == cols a
    nodesFit = size c == rows a
    weightsFit = size b == cols a
-- | Checks the row-sum (consistency) condition @c_i = sum_j a_ij@ for every
-- stage of the tableau. The weights @b@ are not inspected.
--
-- Returns 'False' if @c@ does not have exactly one node per row of @a@;
-- 'zipWith' alone would silently ignore the surplus entries and could
-- report 'True' for a malformed tableau.
--
-- Nodes and row sums are compared with a small relative tolerance rather
-- than '==': entries such as 1/3 are not exactly representable as 'Double',
-- so a mathematically consistent tableau can differ in the last bits.
isConsistent :: ButcherTableau -> Bool
isConsistent (c, a, _b) =
  size c == rows a
    && and (zipWith close (VG.toList c) (map VG.sum (toRows a)))
  where
    close ci rowSum = abs (ci - rowSum) <= 1e-12 * max 1 (abs ci)
-- | Butcher tableau of the forward Euler method: a single stage with node 0,
-- no stage coupling (a 1×1 zero matrix) and weight 1.
btForwardEuler :: ButcherTableau
btForwardEuler = (VG.singleton 0, (1 >< 1) [0], VG.singleton 1)
-- | Butcher tableau of the classical fourth-order Runge–Kutta method (RK4).
--
-- Nodes are 0, 1/2, 1/2, 1. Each stage after the first uses only the slope
-- of the stage immediately before it, scaled by 1/2, 1/2 and 1. The final
-- weights are 1/6, 1/3, 1/3, 1/6.
btRK4 :: ButcherTableau
btRK4 = (nodes, coupling, weights)
  where
    nodes = VG.fromList [0, 1 / 2, 1 / 2, 1]
    coupling =
      (4 >< 4) $
        concat
          [ [0, 0, 0, 0]
          , [1 / 2, 0, 0, 0]
          , [0, 1 / 2, 0, 0]
          , [0, 0, 1, 0]
          ]
    weights = VG.fromList [1 / 6, 1 / 3, 1 / 3, 1 / 6]
-- | One step of the explicit Runge–Kutta method given by the Butcher tableau
-- @(c, a, b)@, for the scalar ODE @y' = f t y@ with step size @h@, starting
-- from value @y@ at time @t@.
--
-- The stage slopes are
-- @k_i = f (t + h * c_i) (y + h * sum [a_ij * k_j | j < i])@
-- (only the strictly lower-triangular part of @a@ is read), and the result
-- is @y + h * sum [b_i * k_i]@.
explicitRungeKutta :: ButcherTableau -> (Double -> Double -> Double) -> Double -> Double -> Double -> Double
explicitRungeKutta (c, a, b) f h t y = y + h * sum (zipWith (*) (VG.toList b) ks)
  where
    -- Exactly one slope per node in @c@; iterating over @[0 .. length c]@
    -- would build an extra stage that indexes @c@ and @a@ out of bounds.
    -- @ks@ is defined lazily in terms of itself: stage @i@ only consumes the
    -- first @i@ slopes (via @take i@), which are already defined.
    ks =
      [ f (t + h * ci) (y + h * sum (zipWith (*) (take i (VG.toList row)) ks))
      | (i, ci, row) <- zip3 [0 ..] (VG.toList c) (toRows a)
      ]
-- | Exact reference values @(t, sin t)@ at @t = 0, 0.01, …, 0.99@ (100
-- samples), for comparison with the @(t, y)@ solver traces in this file.
--
-- The first component is the time @t@, not the sample index, so each entry
-- lines up with the corresponding entry of the solver traces.
test_sin :: [(Double, Double)]
test_sin = [(t, sin t) | i <- [0 .. 99 :: Int], let t = fromIntegral i * h]
  where
    h = 0.01
-- | The first 100 points @(t, y)@ of the forward Euler solution of
-- @y' = cos t@, @y(0) = 0@ (exact solution @sin t@), with step @h = 0.01@.
--
-- The solvers call the right-hand side as @f t y@ (time first), so it must
-- use its first argument. Writing @cos@ of the second argument would solve
-- @y' = cos y@ instead, whose solution is not @sin t@.
test_euler :: [(Double, Double)]
test_euler = take 100 (iterate step (t0, y0))
  where
    f t _ = cos t
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, eulerMethod f h t y)
-- | Infinite trace @(t, y)@ of the classical RK4 solution of @y' = cos t@,
-- @y(0) = 0@ (exact solution @sin t@), with step @h = 0.01@.
-- Consumers must 'take' a prefix.
--
-- The solvers call the right-hand side as @f t y@ (time first), so it must
-- use its first argument. Writing @cos@ of the second argument would solve
-- @y' = cos y@ instead, whose solution is not @sin t@.
test_rk4 :: [(Double, Double)]
test_rk4 = iterate step (t0, y0)
  where
    f t _ = cos t
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, rk4 f h t y)
-- | Infinite trace @(t, y)@ of @y' = cos t@, @y(0) = 0@ (exact solution
-- @sin t@), with step @h = 0.01@. It is computed by the generic tableau
-- solver 'explicitRungeKutta' using 'btRK4'. Consumers must 'take' a prefix.
--
-- Since 'btRK4' encodes the classical RK4 method, this is expected to agree
-- with the direct RK4 trace up to floating-point rounding.
--
-- The solvers call the right-hand side as @f t y@ (time first), so it must
-- use its first argument. Writing @cos@ of the second argument would solve
-- @y' = cos y@ instead.
test_explicitRungeKutta_rk4 :: [(Double, Double)]
test_explicitRungeKutta_rk4 = iterate step (t0, y0)
  where
    f t _ = cos t
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, explicitRungeKutta btRK4 f h t y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment