Skip to content

Instantly share code, notes, and snippets.

@thesz
Created June 12, 2022 17:12
Show Gist options
  • Save thesz/d89ff18ffdd89037c43637bd79c1cbf3 to your computer and use it in GitHub Desktop.
Save thesz/d89ff18ffdd89037c43637bd79c1cbf3 to your computer and use it in GitHub Desktop.
Simple linear algebra routines, including Fisher's linear discriminant
{-# LANGUAGE BangPatterns #-}
module SLA where
import qualified Data.List as List
import qualified Data.Vector.Unboxed as UV
-- | A dense vector, stored as a plain list of its components.
type Vec = [Double]

-- | A dense matrix, stored as a list of rows (row-major); each row is a 'Vec'.
type Mat = [Vec]
-- | Dot product of two vectors. When the lengths differ, the trailing
-- components of the longer vector are ignored (pairing stops at the
-- shorter one).
inner :: Vec -> Vec -> Double
inner xs ys = sum [x * y | (x, y) <- zip xs ys]
-- | Swap the rows and columns of a matrix. Delegates to
-- 'Data.List.transpose', so ragged rows are handled as that function
-- does (missing entries of short rows are skipped).
transpose :: Mat -> Mat
transpose rows = List.transpose rows
-- | Matrix-vector product: the result has one component per matrix row,
-- namely that row's dot product with the vector.
mmv :: Mat -> Vec -> Vec
mmv rows vec = [inner vec row | row <- rows]
-- | Outer product with one row per element of the SECOND argument: row @i@
-- is the first vector scaled by the @i@-th element of the second vector.
--
-- So @outer a b@ has @length b@ rows of @length a@ columns, i.e. it is
-- @b aᵀ@, the transpose of the more common @a bᵀ@ convention. The two agree
-- for the symmetric case @outer y y@.
outer :: Vec -> Vec -> Mat
outer a b = map (\y -> map (* y) a) b
infixl 6 .-., .+.

-- | Componentwise sum and difference of two vectors. As with 'zipWith',
-- the result is only as long as the shorter argument.
(.+.), (.-.) :: Vec -> Vec -> Vec
u .+. w = zipWith (+) u w
u .-. w = zipWith (-) u w
infixl 7 .#*.

-- | Multiply every component of a vector by a scalar. Its fixity (7) is
-- higher than that of '.+.' and '.-.' (6), so @x .+. k .#*. p@ parses as
-- @x .+. (k .#*. p)@.
(.#*.) :: Double -> Vec -> Vec
(.#*.) k = map (k *)
-- | Solve @A x = b@ with the conjugate-gradient method.
--
-- The iteration starts from @x = 0@ and stops once the squared residual
-- norm @r·r@ drops below @1e-5@. CG is only guaranteed to converge for
-- symmetric positive-definite @A@.
--
-- Termination is also guaranteed on other inputs: the iteration stops after
-- a bounded number of steps (100 per unknown), or as soon as a step length
-- would be NaN/infinite. Examples of such inputs are a singular @A@ where
-- @p·Ap = 0@, NaN entries, and floating-point stagnation. In those cases the
-- current iterate is returned.
conjGrad :: Mat -> Vec -> Vec
conjGrad a b = conjGradWith 1e-5 (100 * max 1 (length b)) a b

-- | 'conjGrad' with an explicit tolerance on the squared residual norm and
-- an explicit limit on the number of CG steps.
conjGradWith
  :: Double -- ^ stop when the squared residual norm @r·r@ is below this
  -> Int    -- ^ maximum number of CG steps
  -> Mat    -- ^ system matrix @A@ (list of rows)
  -> Vec    -- ^ right-hand side @b@
  -> Vec
conjGradWith tol maxIter a b = go maxIter b b (map (const 0) b)
  where
    -- rk: residual b - A xk; pk: search direction; xk: current estimate.
    -- With x0 = 0 the initial residual and direction are both b.
    go !budget rk pk xk
      | irk < tol = xk
      | budget <= 0 = xk
      -- A zero curvature p·Ap (singular A) or NaN/Inf data makes the step
      -- non-finite. Taking it would turn every later residual into NaN/Inf,
      -- so the tolerance test above could never succeed again.
      | isNaN ak || isInfinite ak = xk
      | otherwise = forceVec xk1 `seq` go (budget - 1) rk1 pk1 xk1
      where
        irk = inner rk rk
        apk = mmv a pk
        ak = irk / inner pk apk
        xk1 = xk .+. ak .#*. pk
        rk1 = rk .-. ak .#*. apk
        bk = inner rk1 rk1 / irk
        pk1 = rk1 .+. bk .#*. pk

    -- xk is never inspected until the end, so force its components each
    -- step to avoid building one thunk per iteration. rk and pk are already
    -- forced by the inner products in the guards.
    forceVec :: Vec -> ()
    forceVec = foldr seq ()
-- | Fisher's linear discriminant for two classes.
--
-- @fld samples classes@ works as follows:
--
-- * It pairs each sample with its label (@False@ = class 0, @True@ =
--   class 1).
-- * It computes the class means @mu0@ and @mu1@.
-- * It computes the within-class scatter matrix
--   @Sw = Σ (x - mu_c)(x - mu_c)ᵀ@.
-- * It solves @Sw v = mu0 - mu1@ with 'conjGrad'.
--
-- The result is the direction @v@ together with the projection @v·x@ of
-- every sample, in input order.
--
-- Degenerate inputs:
--
-- * No samples: returns @([], [])@.
-- * One class has no samples: its mean is undefined, so there is no
--   discriminating direction. @v@ is the zero vector and every projection
--   is 0.
-- * @samples@ and @classes@ of different lengths: the statistics use only
--   the paired prefix (as 'zip' does), but every sample is still projected.
--
-- The sums are accumulated strictly, so memory use does not grow with the
-- number of samples.
fld :: [Vec] -> [Bool] -> (Vec, Vec)
fld [] _ = ([], [])
fld samples@(firstSample : _) classes = (v, map (inner v) samples)
  where
    labelled = zip samples classes
    zeroVec = map (const 0) firstSample
    zeroMat = map (const zeroVec) zeroVec

    (sum0, count0, sum1, count1) = classSums zeroVec 0 zeroVec 0 labelled
    mu0 = (1 / count0) .#*. sum0
    mu1 = (1 / count1) .#*. sum1
    dmu = mu0 .-. mu1

    sw = scatter zeroMat labelled

    -- An empty class would make its mean NaN (1/0 * 0); NaN input would
    -- then poison the solve, so bail out with the zero direction instead.
    v | count0 == 0 || count1 == 0 = zeroVec
      | otherwise = conjGrad sw dmu

    -- Running per-class sums and counts: class False, then class True.
    classSums
      :: Vec -> Double -> Vec -> Double -> [(Vec, Bool)]
      -> (Vec, Double, Vec, Double)
    classSums s0 !c0 s1 !c1 [] = (s0, c0, s1, c1)
    classSums s0 !c0 s1 !c1 ((s, c) : rest)
      | c =
          let s1' = s1 .+. s
          in forceVec s1' `seq` classSums s0 c0 s1' (c1 + 1) rest
      | otherwise =
          let s0' = s0 .+. s
          in forceVec s0' `seq` classSums s0' (c0 + 1) s1 c1 rest

    -- Adds the outer product of each sample centred on its class mean.
    scatter :: Mat -> [(Vec, Bool)] -> Mat
    scatter m [] = m
    scatter m ((s, c) : rest) =
      let y = s .-. (if c then mu1 else mu0)
          m' = zipWith (.+.) m (outer y y)
      in forceMat m' `seq` scatter m' rest

    -- Deep-force accumulators so they do not build per-sample thunk chains.
    forceVec :: Vec -> ()
    forceVec = foldr seq ()

    forceMat :: Mat -> ()
    forceMat = foldr (seq . forceVec) ()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment