Skip to content

Instantly share code, notes, and snippets.

@nihemak
Last active December 27, 2015 09:19
Show Gist options
  • Save nihemak/7303256 to your computer and use it in GitHub Desktop.
Save nihemak/7303256 to your computer and use it in GitHub Desktop.
http://nineties.github.io/math-seminar/7.html 練習問題(ガウス消去法/ピボット選択)をHaskellでmutableなarrayを使用して書いてみた
{-
- http://nineties.github.io/math-seminar/7.html
- 練習問題(ガウス消去法/ピボット選択)をHaskellで書いてみた
- (mutableなarrayを使用)
-}
import Data.Array
import Data.Array.ST
import Data.Array.MArray
import Control.Monad
import Control.Monad.State
{- Purpose: solve a system of linear equations with Gaussian elimination
   (no pivoting).

   The argument is an augmented matrix (n rows, n+1 columns) whose last
   column is the right-hand side.  A mutable copy is forward-eliminated and
   then back-substituted, and that copy is returned, so the input array
   itself is left untouched.  In the result the last column holds the
   solution.  A zero pivot is not detected: the division simply produces
   Infinity/NaN. -}
solve :: Array (Int, Int) Double -> Array (Int, Int) Double
solve a = runSTArray $ thaw a >>= \m -> do
    forward_elimination dims m
    backward_substitution dims m
    pure m
  where
    dims = bounds a
{- Self-check for solve on the 3x3 system
       x + 2y + 3z = 1,  -x + 3y + 2z = 2,  2x + y + 2z = 3
   whose exact solution is (x, y, z) = (2.6, 3.4, -2.8).  The results are
   compared with a symmetric absolute tolerance rather than (==), because
   elimination introduces floating-point rounding. -}
test_solve :: Bool
test_solve = and (zipWith close [r ! (k, 3) | k <- [0 .. 2]] [2.6, 3.4, -2.8])
  where
    close x y = abs (x - y) < 1e-9
    a = listArray ((0, 0), (2, 3)) [  1, 2, 3, 1
                                   , -1, 3, 2, 2
                                   ,  2, 1, 2, 3 ]
    r = solve a
{- Purpose: forward-eliminate the augmented matrix `a` in place, without
   pivoting, so that it becomes upper triangular.

   The first argument is the array's bounds ((l1,m1),(l2,m2)).  Step l is
   an offset 0 .. l2-l1 from the lower bounds: it uses row l1+l as the
   pivot row and column m1+l as the pivot column, and subtracts a multiple
   of the pivot row from every row below it.  Because l is an offset, any
   lower bound works, not just 0.

   Within a row, columns are updated from m2 down to m1+l.  This order
   ensures the pivot-column entry (the multiplier's numerator) is read
   before it is overwritten.  A zero pivot is not detected and yields
   Infinity/NaN. -}
forward_elimination
  :: (MArray arr e m, Fractional e, Ix i, Num i, Enum i)
  => ((i, i), (i, i)) -> arr (i, i) e -> m ()
forward_elimination ((l1,m1),(l2,m2)) a =
    mapM_ elim [ (l, i, j) | l <- [0 .. l2 - l1]
                           , i <- [l1 + l + 1 .. l2]
                           , j <- [m2, m2 - 1 .. m1 + l] ]
  where
    -- Row operation for column j of row i at step l:
    --   a(i,j) -= a(l1+l, j) * a(i, m1+l) / a(l1+l, m1+l)
    elim (l, i, j) = do
      ij <- readArray a (i, j)
      lj <- readArray a (l1 + l, j)
      im <- readArray a (i, m1 + l)
      lm <- readArray a (l1 + l, m1 + l)
      writeArray a (i, j) (ij - lj * im / lm)
{- Purpose: back-substitute an upper-triangular augmented matrix in place.

   The first argument is the array's bounds ((l1,m1),(l2,m2)); m1 is not
   used.  Row i's diagonal coefficient is taken from column m2-l2+i-1.
   This assumes the matrix has exactly one more column than rows, with
   column m2 being the right-hand side.

   Rows are processed bottom-up.  First, row i's right-hand side is
   divided by its diagonal coefficient, which is then set to 1.  Next,
   that solved value is removed from the right-hand side of every row
   above it, and their coefficients in that column are set to 0.
   Afterwards column m2 holds the solution.  A zero diagonal coefficient
   yields Infinity/NaN. -}
backward_substitution ((l1,_),(l2,m2)) a =
    forM_ [l2, l2 - 1 .. l1] $ \i -> do
      normalise i
      forM_ [i - 1, i - 2 .. l1] (clearAbove i)
  where
    -- Column holding row i's diagonal coefficient.
    diag i = m2 - l2 + i - 1

    -- Divide row i's right-hand side by its pivot and mark the pivot as 1.
    normalise i = do
      rhs <- readArray a (i, m2)
      piv <- readArray a (i, diag i)
      writeArray a (i, diag i) 1
      writeArray a (i, m2) (rhs / piv)

    -- Remove row i's solved value from row j (j above i); zero that entry.
    clearAbove i j = do
      rhsJ <- readArray a (j, m2)
      coef <- readArray a (j, diag i)
      xi   <- readArray a (i, m2)
      writeArray a (j, m2) (rhsJ - coef * xi)
      writeArray a (j, diag i) 0
{- Purpose: solve a system of linear equations with Gaussian elimination
   using partial pivoting.

   The argument is an augmented matrix (n rows, n+1 columns) whose last
   column is the right-hand side.  A mutable copy is forward-eliminated
   with row swaps and then back-substituted.  That copy is returned, so
   the input array is left untouched.  In the result the last column
   holds the solution. -}
solve' :: Array (Int, Int) Double -> Array (Int, Int) Double
solve' a = runSTArray $ thaw a >>= \m -> do
    forward_elimination' dims m
    backward_substitution dims m
    pure m
  where
    dims = bounds a
{- Self-check for solve' on the 3x3 system
       2x + 3y - z = 5,  4x + 6y - 3z = 3,  2x - 3y + z = -1
   whose exact solution is (x, y, z) = (1, 10/3, 7).  Eliminating without
   row swaps hits a zero second pivot here, so this exercises pivoting.
   Results are compared with a symmetric absolute tolerance rather than
   (==), so the check does not depend on exact rounding. -}
test_solve' :: Bool
test_solve' = and (zipWith close [r ! (k, 3) | k <- [0 .. 2]] [1, 10 / 3, 7])
  where
    close x y = abs (x - y) < 1e-9
    a = listArray ((0, 0), (2, 3)) [ 2,  3, -1,  5
                                   , 4,  6, -3,  3
                                   , 2, -3,  1, -1 ]
    r = solve' a
{- Purpose: forward-eliminate the augmented matrix `a` in place using
   partial pivoting, so that it becomes upper triangular.

   The first argument is the array's bounds ((l1,m1),(l2,m2)).  Step l is
   an offset 0 .. l2-l1 from the lower bounds, so any lower bound works.
   Each step does two things:
     1. Among rows l1+l .. l2, find the row whose entry in pivot column
        m1+l has the largest absolute value (ties keep the earlier row),
        and swap it into row l1+l.
     2. Subtract a multiple of that pivot row from every row below it.

   Within a row, columns are updated from m2 down to m1+l.  This order
   ensures the pivot-column entry is read before it is overwritten.  If
   every candidate pivot is 0, the division yields Infinity/NaN; no error
   is raised. -}
forward_elimination'
  :: (MArray arr e m, Ord e, Fractional e, Ix i, Num i, Enum i)
  => ((i, i), (i, i)) -> arr (i, i) e -> m ()
forward_elimination' ((l1,m1),(l2,m2)) a = mapM_ step [0 .. l2 - l1]
  where
    -- One elimination step at offset l: pivot, then clear the rows below.
    step l = do
      swap l
      mapM_ (\(i, j) -> elim i j l)
            [ (i, j) | i <- [l1 + l + 1 .. l2], j <- [m2, m2 - 1 .. m1 + l] ]

    -- Offset (from row l1+l) of the row with the largest |a(r, m1+l)|.
    pivot l = do
      v0 <- readArray a (l1 + l, m1 + l)
      (p, _) <- foldM (pick l) (0, abs v0) [1 .. l2 - l1 - l]
      return p

    -- Fold step for pivot: read candidate offset p2's own entry.  It
    -- replaces the current best only if its magnitude is strictly larger.
    pick l (p1, v1) p2 = do
      v2 <- abs <$> readArray a (l1 + l + p2, m1 + l)
      return (if v2 > v1 then (p2, v2) else (p1, v1))

    -- Swap the chosen pivot row into row l1+l.  Only columns m1+l .. m2
    -- are exchanged.
    swap l = do
      p <- pivot l
      mapM_ (swap' p l) [m2, m2 - 1 .. m1 + l]

    swap' p l m = do
      x <- readArray a (l1 + l, m)
      y <- readArray a (l1 + l + p, m)
      writeArray a (l1 + l, m) y
      writeArray a (l1 + l + p, m) x

    -- Row operation for column j of row i at step n:
    --   a(i,j) -= a(l1+n, j) * a(i, m1+n) / a(l1+n, m1+n)
    elim i j n = do
      ij <- readArray a (i, j)
      lj <- readArray a (l1 + n, j)
      im <- readArray a (i, m1 + n)
      lm <- readArray a (l1 + n, m1 + n)
      writeArray a (i, j) (ij - lj * im / lm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment