Created June 12, 2022 17:12
Save thesz/d89ff18ffdd89037c43637bd79c1cbf3 to your computer and use it in GitHub Desktop.
Simple linear algebra routines, including Fisher's linear discriminant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns #-} | |
module SLA where | |
import qualified Data.List as List | |
import qualified Data.Vector.Unboxed as UV | |
-- | A dense vector of doubles.
type Vec = [Double]

-- | A dense matrix stored row by row: a list of rows, each one a 'Vec'.
-- The routines in this module do not check that rows have equal length.
type Mat = [[Double]]
-- | Dot product of two vectors. Pairing stops at the shorter vector, so
-- trailing elements of the longer one are ignored.
inner :: Vec -> Vec -> Double
inner xs ys = sum [x * y | (x, y) <- zip xs ys]
-- | Swap the rows and columns of a row-major matrix. This is a thin alias for
-- 'Data.List.transpose', so ragged input follows that function's rules.
transpose :: Mat -> Mat
transpose rows = List.transpose rows
-- | Matrix-vector product @m v@. Entry @i@ of the result is the dot product
-- of row @i@ of @m@ with @v@.
mmv :: Mat -> Vec -> Vec
mmv m v = [inner v row | row <- m]
-- | Outer product with one row per element of the SECOND argument.
--
-- Row @j@ is @a@ scaled by the @j@-th element of @b@, so the result is
-- @b * a^T@ rather than the more common @a * b^T@. The two coincide when both
-- arguments are the same vector, as in the @outer y y@ call inside 'fld'.
outer :: Vec -> Vec -> Mat
outer a b = map (\y -> map (* y) a) b
infixl 6 .+., .-.

-- | Element-wise sum and difference of two vectors. The result is as long as
-- the shorter operand.
(.+.), (.-.) :: Vec -> Vec -> Vec
xs .+. ys = zipWith (+) xs ys
xs .-. ys = zipWith (-) xs ys
infixl 7 .#*.

-- | Multiply every element of a vector by a scalar. This operator binds
-- tighter than '.+.' and '.-.', so @x .+. k .#*. p@ parses as
-- @x .+. (k .#*. p)@.
(.#*.) :: Double -> Vec -> Vec
k .#*. xs = [k * x | x <- xs]
-- | Solve @a x = b@ with the conjugate gradient method, starting from @x = 0@.
--
-- CG is only guaranteed to converge when @a@ (row-major) is symmetric and
-- positive definite. This is not checked. Iteration stops at the first of:
--
-- * the squared residual norm @r . r@ drops below @1e-5@ (absolute);
-- * @max 1000 (10 * length b)@ steps have been taken;
-- * numerical breakdown occurs.
--
-- See 'conjGradWith' for the details of each stopping rule.
conjGrad :: Mat -> Vec -> Vec
conjGrad a b = conjGradWith 1e-5 (max 1000 (10 * length b)) a b

-- | Conjugate gradient with an explicit stopping rule.
--
-- @conjGradWith tol maxIter a b@ iterates from @x = 0@ and returns the current
-- iterate as soon as any of these holds:
--
-- * @r . r < tol@ (converged);
-- * @maxIter@ steps have been taken. An absolute @tol@ can be unreachable in
--   floating point for large-magnitude or badly conditioned data, and without
--   this cap the loop would never terminate;
-- * the step length @(r . r) / (p . a p)@ is NaN or infinite, for example
--   because @a@ is singular along @p@ or the inputs contain NaN. Continuing
--   would make every later iterate NaN, and NaN never satisfies @< tol@.
conjGradWith :: Double -> Int -> Mat -> Vec -> Vec
conjGradWith tol maxIter a b = go maxIter b b (map (const 0) b)
  where
    -- rk: residual b - a xk; pk: search direction; xk: current estimate.
    go iter rk pk xk
      | irk < tol = xk
      | iter <= 0 = xk
      | isNaN ak || isInfinite ak = xk
      | otherwise = forceAll xk1 `seq` go (iter - 1) rk1 pk1 xk1
      where
        irk = inner rk rk
        apk = mmv a pk
        ak = irk / inner pk apk
        xk1 = xk .+. ak .#*. pk
        rk1 = rk .-. ak .#*. apk
        bk = inner rk1 rk1 / irk
        pk1 = rk1 .+. bk .#*. pk
    -- rk and pk are forced every step by the inner products above, but xk is
    -- not. Force it here so it doesn't grow a chain of lazy additions that is
    -- as long as the iteration count.
    forceAll v = foldr seq () v
-- | Fisher's linear discriminant for a two-class problem.
--
-- Each sample is paired with the 'Bool' at the same position in @classes@:
-- @False@ marks class 0 and @True@ marks class 1. The result is @(v, proj)@:
--
-- * @v@ is the 'conjGrad' solution of @Sw v = mu0 - mu1@. Here @mu0@ and
--   @mu1@ are the class means, and @Sw@ is the within-class scatter matrix:
--   the sum of @outer y y@ with @y = x - (mean of x's class)@ over all samples;
-- * @proj@ is @inner v x@ for every sample @x@, in input order.
--
-- Calls 'error' in two cases:
--
-- * @samples@ and @classes@ have different lengths;
-- * either class is empty, which includes empty input. The empty class's
--   mean would be @0/0@ (NaN), and a NaN right-hand side keeps 'conjGrad'
--   from converging.
fld :: [Vec] -> [Bool] -> (Vec, Vec)
fld samples classes
  | length samples /= length classes =
      error "SLA.fld: samples and classes must have the same length"
  | n0 == 0 || n1 == 0 =
      error "SLA.fld: each class needs at least one sample"
  | otherwise = (v, map (inner v) samples)
  where
    labelled = zip samples classes
    n1 = length (filter id classes)
    n0 = length classes - n1

    -- All-zero vector/matrix of the sample dimension. Only demanded after the
    -- guards above pass, at which point samples is known to be non-empty.
    zeroVec = case samples of
      (s : _) -> map (const 0) s
      [] -> []
    zeroMat = map (const zeroVec) zeroVec

    -- Class means: sum the class's samples in input order, then scale by
    -- 1/count.
    classSum cls =
      List.foldl' (\acc x -> forceVec (acc .+. x)) zeroVec
        [x | (x, c) <- labelled, c == cls]
    mu0 = (1 / fromIntegral n0) .#*. classSum False
    mu1 = (1 / fromIntegral n1) .#*. classSum True
    dmu = mu0 .-. mu1

    -- Within-class scatter, accumulated over all samples in input order.
    sw = List.foldl' addScatter zeroMat labelled
    addScatter m (x, c) =
      let y = x .-. (if c then mu1 else mu0)
      in forceMat (zipWith (.+.) m (outer y y))

    v = conjGrad sw dmu

    -- foldl' only evaluates the accumulator to WHNF. These helpers force every
    -- element, so the running sums stay O(dimension^2) in memory instead of
    -- building thunk chains that grow with the number of samples.
    forceVec xs = foldr seq xs xs
    forceMat rows = foldr (seq . forceVec) rows rows
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment