Created
March 8, 2021 15:20
-
-
Save msakai/a7476722ddf1360a5f06af4d7d3ba75f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
import Control.Exception | |
import qualified Data.Vector.Generic as VG | |
import Numeric.LinearAlgebra | |
-- | One step of the forward (explicit) Euler method.
--
-- Given the right-hand side @f@ of @y' = f t y@, a step size @h@, the
-- current time @t@ and the current value @y@, returns @y + h * f t y@.
eulerMethod :: Fractional a => (a -> a -> a) -> a -> a -> a -> a
eulerMethod f h t y = y + h * slope
  where
    -- Slope evaluated at the start of the step.
    slope = f t y
-- https://ja.wikipedia.org/wiki/%E3%83%AB%E3%83%B3%E3%82%B2%EF%BC%9D%E3%82%AF%E3%83%83%E3%82%BF%E6%B3%95
-- | One step of the classical Runge-Kutta method (RK4).
--
-- Given the right-hand side @f@ of @y' = f t y@, a step size @h@, the
-- current time @t@ and the current value @y@, returns
-- @y + (h / 6) * (k1 + 2*k2 + 2*k3 + k4)@ where the @k@s are the four
-- stage slopes computed below.
rk4 :: Fractional a => (a -> a -> a) -> a -> a -> a -> a
rk4 f h t y = y + (h / 6) * sum [s1, 2 * s2, 2 * s3, s4]
  where
    halfStep = h / 2
    tMid = t + halfStep
    -- State reached by moving a distance @dt@ along the given slope.
    advance dt slope = y + dt * slope
    s1 = f t y
    s2 = f tMid (advance halfStep s1)
    s3 = f tMid (advance halfStep s2)
    s4 = f (t + h) (advance h s3)
-- | One step of the classical Runge-Kutta method (RK4) for a vector-valued
-- state.
--
-- Same scheme as 'rk4', but @y@ and the slopes returned by @f@ are
-- @Vector Double@; the scalar step factors are lifted with 'scalar' before
-- being multiplied with the slope vectors.
rk4Vec :: (Double -> Vector Double -> Vector Double) -> Double -> Double -> Vector Double -> Vector Double
rk4Vec f h t y = y + scalar (h / 6) * sum [s1, 2 * s2, 2 * s3, s4]
  where
    halfStep = h / 2
    tMid = t + halfStep
    -- State reached by moving a distance @dt@ along the given slope vector.
    advance dt slope = y + scalar dt * slope
    s1 = f t y
    s2 = f tMid (advance halfStep s1)
    s3 = f tMid (advance halfStep s2)
    s4 = f (t + h) (advance h s3)
-- | Butcher tableau of a Runge-Kutta method, stored as the triple @(c, a, b)@.
type ButcherTableau =
  ( Vector Double -- c: per-stage factors of the step size added to the time
  , Matrix Double -- a: row i holds the coefficients applied to earlier stage slopes
  , Vector Double -- b: weights of the stage slopes in the final update
  )
-- | Assert that the components of a tableau have compatible shapes, then
-- return the second argument unchanged.
--
-- The checks are, in order: the coefficient matrix is square, @c@ has one
-- entry per matrix row, and @b@ has one entry per matrix column.  They are
-- expressed with 'assert', so they only fire when assertions are enabled.
checkButcherTableau :: ButcherTableau -> a -> a
checkButcherTableau (c, a, b) x =
  assert isSquare . assert nodesMatchRows . assert weightsMatchCols $ x
  where
    isSquare = rows a == cols a
    nodesMatchRows = size c == rows a
    weightsMatchCols = size b == cols a
-- | Check the row-sum condition of a tableau: every entry of @c@ must equal
-- the sum of the corresponding row of the coefficient matrix.
--
-- The comparison accepts exact matches and, additionally, differences within
-- a small relative tolerance.  Row sums are computed in floating point, so a
-- tableau with entries such as thirds can miss exact equality by one rounding
-- step even though it is consistent.
--
-- As with the pairing via 'zipWith', only the common prefix of @c@ and the
-- rows is compared; the weights @b@ are not inspected.
isConsistent :: ButcherTableau -> Bool
isConsistent (c, a, _b) = and $ zipWith closeEnough (VG.toList c) (map VG.sum (toRows a))
  where
    tolerance :: Double
    tolerance = 1e-12
    -- The exact test is kept first so that everything the strict comparison
    -- accepted (including equal infinities) is still accepted.
    closeEnough ci rowSum =
      ci == rowSum || abs (ci - rowSum) <= tolerance * max 1 (abs ci)
-- | Single-stage tableau for the forward Euler method:
-- @c = [0]@, @a = [[0]]@, @b = [1]@.
btForwardEuler :: ButcherTableau
btForwardEuler = (nodes, coefficients, weights)
  where
    nodes = VG.fromList [0]
    coefficients = (1 >< 1) [0]
    weights = VG.fromList [1]
-- | Four-stage tableau of the classical Runge-Kutta method (RK4).
btRK4 :: ButcherTableau
btRK4 = (nodes, coefficients, weights)
  where
    nodes = VG.fromList [0, 1/2, 1/2, 1]
    -- Written row by row; flattened in row-major order for '><'.
    coefficients =
      (4 >< 4) $
        concat
          [ [0,   0,   0, 0]
          , [1/2, 0,   0, 0]
          , [0,   1/2, 0, 0]
          , [0,   0,   1, 0]
          ]
    weights = VG.fromList [1/6, 1/3, 1/3, 1/6]
-- | One step of the explicit Runge-Kutta method described by a Butcher
-- tableau @(c, a, b)@.
--
-- For the right-hand side @f@ of @y' = f t y@, step size @h@, time @t@ and
-- value @y@, stage @i@ evaluates @f@ at time @t + h * c_i@ and at the state
-- @y + h * sum_{j < i} a_ij * k_j@; the result is @y + h * sum_i b_i * k_i@.
--
-- Only the entries of row @i@ left of the diagonal are read, so the method
-- is treated as explicit regardless of what the rest of the matrix holds.
explicitRungeKutta :: ButcherTableau -> (Double -> Double -> Double) -> Double -> Double -> Double -> Double
explicitRungeKutta (c, a, b) f h t y = y + h * sum (zipWith (*) (VG.toList b) ks)
  where
    stageCount = VG.length c
    -- Exactly one slope per stage.  The range previously ran to
    -- @VG.length c@ inclusive, creating an extra stage that would index past
    -- the end of @c@ and @a@ if it were ever forced.
    ks = map stage [0 .. stageCount - 1]
    stage i = f (t + h * (c ! i)) (y + h * sum [(row ! j) * kj | (j, kj) <- zip [0 .. i - 1] ks])
      where
        row = a ! i
    -- Pairing the column indices with @ks@ walks the list of slopes once per
    -- stage instead of indexing it with @!!@.  The index list comes first in
    -- 'zip', so no slope beyond the first @i@ is demanded while stage @i@ is
    -- being defined.
-- | Reference table: for each index @i@ from 0 to 99, the pair
-- @(i, sin (i * 0.01))@.
test_sin :: [(Double, Double)]
test_sin = map sample [0 .. 99]
  where
    sample i = (i, sin (i * 0.01))
-- | First 100 samples @(t, y)@ produced by 'eulerMethod' for @y' = cos t@,
-- @y(0) = 0@, with step size 0.01.
test_euler :: [(Double, Double)]
test_euler = take 100 $ iterate step (t0, y0)
  where
    -- NOTE(review): the right-hand side used to be @f _ x = cos x@, i.e. the
    -- cosine of the state (y' = cos y), whose solution is not sin t.  The
    -- reference table 'test_sin' suggests y' = cos t was intended, so the
    -- time argument is used here; confirm with the author.
    f tau _ = cos tau
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, eulerMethod f h t y)
-- | Infinite list of samples @(t, y)@ produced by 'rk4' for @y' = cos t@,
-- @y(0) = 0@, with step size 0.01.
test_rk4 :: [(Double, Double)]
test_rk4 = iterate step (t0, y0)
  where
    -- NOTE(review): the right-hand side used to be @f _ x = cos x@, i.e. the
    -- cosine of the state (y' = cos y), whose solution is not sin t.  The
    -- reference table 'test_sin' suggests y' = cos t was intended, so the
    -- time argument is used here; confirm with the author.
    f tau _ = cos tau
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, rk4 f h t y)
-- | Infinite list of samples @(t, y)@ produced by 'explicitRungeKutta' with
-- the 'btRK4' tableau for @y' = cos t@, @y(0) = 0@, with step size 0.01.
test_explicitRungeKutta_rk4 :: [(Double, Double)]
test_explicitRungeKutta_rk4 = iterate step (t0, y0)
  where
    -- NOTE(review): the right-hand side used to be @f _ x = cos x@, i.e. the
    -- cosine of the state (y' = cos y), whose solution is not sin t.  The
    -- reference table 'test_sin' suggests y' = cos t was intended, so the
    -- time argument is used here; confirm with the author.
    f tau _ = cos tau
    h = 0.01
    t0 = 0
    y0 = 0
    step (t, y) = (t + h, explicitRungeKutta btRK4 f h t y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment