Created
August 6, 2017 13:45
-
-
Save mhuesch/dc75478547c56d8d49e1730070588187 to your computer and use it in GitHub Desktop.
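Motivated by Andrew Ng's Machine Learning course on Coursera, I used automatic differentiation (AD) to minimize the linear-regression cost function by gradient descent, and compared the result with the analytic (closed-form least-squares) solution.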
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main where | |
import Numeric.AD.Newton | |
-- | Fit a line to 'pairs' twice (by AD gradient descent and by the closed
-- form) and print both (intercept, slope) estimates for comparison.
main :: IO ()
main = do
  let n = 500
  putStrLn $ "gradientDescent after " ++ show n ++ " iterations"
  -- The stream returned by 'gradientDescent' can be finite (ad's
  -- implementation ends it e.g. when the gradient is exactly zero), so look
  -- up element n totally instead of with the partial (!!), which would crash
  -- on a short stream.
  case drop n gradDesc of
    (theta : _) -> print theta
    [] -> do
      putStrLn $
        "(stream ended early after " ++ show (length gradDesc)
          ++ " elements; showing the last one)"
      mapM_ print (take 1 (reverse gradDesc))
  putStrLn "analytic form"
  print (analyticFit pairs)
-- | Sample (x, y) data points fitted by both methods below. The point
-- (-1, 10000) is a deliberate outlier relative to the other points.
pairs :: (Num a) => [(a,a)]
pairs = zip xs ys
  where
    xs = [1, 2, 3, -1, 0, -11]
    ys = [2, 1, 7, 10000, -2, -101]
-------------------------------------------------------------------------------- | |
-- AD gradient descent | |
-------------------------------------------------------------------------------- | |
-- | Linear-regression hypothesis: the line with intercept @th0@ and slope
-- @th1@, as a function of the input x (i.e. @th0 + th1 * x@).
h :: Num a => a -> a -> a -> a
h th0 th1 = (th0 +) . (th1 *)
-- | Linear-regression cost of the line @h th0 th1@ on a set of (x, y)
-- points: half the mean of the squared residuals,
--
-- > J(th0, th1) = (1 / (2*m)) * sum [ (h th0 th1 x - y)^2 | (x, y) <- points ]
--
-- where @m@ is the number of points (the conventional 1/2 factor only scales
-- the cost and its gradient). An empty point set is given cost 0; the
-- unguarded formula would evaluate (1/0) * 0, which is NaN for 'Double' and
-- a division-by-zero error for 'Rational'.
j :: Fractional a => [(a, a)] -> a -> a -> a
j [] _ _ = 0
j points th0 th1 = (1 / (2 * m)) * sum (map sqResidual points)
  where
    m = fromIntegral (length points)
    -- Squared vertical distance between one point and the line.
    sqResidual (x, y) = (h th0 th1 x - y) ^ (2 :: Int)
-- | Stream of parameter vectors @[th0, th1]@ produced by AD gradient descent
-- on the cost 'j' over 'pairs', starting from @[1, 1]@. The explicit type
-- pins the element type that defaulting previously chose implicitly.
gradDesc :: [[Double]]
gradDesc = gradientDescent cost [1, 1]
  where
    -- gradientDescent is expected to keep the shape of the starting point
    -- (two parameters); the fallback equation makes the match exhaustive and
    -- reports a descriptive error instead of a bare pattern-match failure.
    cost [th0, th1] = j pairs th0 th1
    cost _ = error "gradDesc: expected exactly two parameters [th0, th1]"
-------------------------------------------------------------------------------- | |
-- Analytic form | |
-------------------------------------------------------------------------------- | |
-- | Closed-form ordinary least-squares fit of a line @y = a*x + b@ to the
-- given points, returned as @(b, a)@: intercept first, then slope.
--
-- The slope is the sum of x/y co-deviations divided by the sum of squared x
-- deviations; the intercept makes the line pass through the point
-- (mean x, mean y). With no points, or with all x values equal, the slope is
-- undefined: for 'Double' the result is NaN or numerically meaningless, and
-- for 'Rational' the division raises an exception.
analyticFit :: Fractional a => [(a,a)] -> (a,a)
analyticFit ps = (intercept, slope)
  where
    (xs, ys) = unzip ps
    mean vs = sum vs / fromIntegral (length vs)
    meanX = mean xs
    meanY = mean ys
    dxs = map (subtract meanX) xs
    dys = map (subtract meanY) ys
    slope = sum (zipWith (*) dxs dys) / sum (map (\d -> d * d) dxs)
    intercept = meanY - slope * meanX
-- | Arithmetic mean of a list: the sum divided by the element count. An
-- empty list divides 0 by 0 (NaN for 'Double', an error for 'Rational').
average :: Fractional a => [a] -> a
average values = total / count
  where
    total = sum values
    count = fromIntegral (length values)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment