This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- 0. Download x.data and y.dat from https://www.kaggle.com/masterdezign/iris-with-onehotencoded-targets | |
-- 1. Install stack (command line interface is marked by $): | |
-- $ wget -qO- https://get.haskellstack.org/ | sh | |
-- (alternatively, curl -sSL https://get.haskellstack.org/ | sh) | |
-- 2. Install open-blas from https://www.openblas.net/ (needed for hmatrix package) | |
-- 3. Run | |
-- $ stack --resolver lts-10.6 --install-ghc runghc --package hmatrix-0.18.2.0 Iris.hs | |
import Numeric.LinearAlgebra as LA | |
import qualified Data.Vector.Storable as V | |
-- New weights | |
newW (nin, nout) = do | |
let k = sqrt (1.0 / fromIntegral nin) | |
w <- randn nin nout | |
return (cmap (k *) w) | |
-- Transformations | |
loss y tgt = | |
let diff = y - tgt | |
in sumElements $ cmap (^2) diff | |
sigmoid = cmap f | |
where | |
f x = recip $ 1.0 + exp (-x) | |
-- Their gradients | |
sigmoid' y dY = dY * y1 * (ones - y1) | |
where | |
y1 = sigmoid y | |
ones = (rows y1) >< (cols y1) $ repeat 1.0 | |
linear' x dy = cmap (/ m) (tr' x LA.<> dy) | |
where | |
m = fromIntegral $ rows x | |
loss' y tgt = | |
let diff = y - tgt | |
in cmap (* 2) diff | |
-- Building NN | |
forward x w1 = | |
let h = x LA.<> w1 | |
y = sigmoid h | |
in [h, y] | |
descend gradF iterN gamma x0 = take iterN (iterate step x0) | |
where | |
step x = x - gamma * gradF(x) | |
grad (x, y) w1 = dW1 | |
where | |
[h, y_pred] = forward x w1 | |
dE = loss' y_pred y | |
dY = sigmoid' h dE | |
dW1 = linear' x dY | |
main = do | |
dta <- loadMatrix "x.dat" | |
tgt <- loadMatrix "y.dat" | |
let (nin, nout) = (4, 3) | |
w1_rand <- newW (nin, nout) | |
let epochs = 500 | |
let w1 = last $ descend (grad (dta, tgt)) epochs 0.01 w1_rand | |
[_, y_pred0] = forward dta w1_rand | |
[_, y_pred] = forward dta w1 | |
putStrLn $ "Initial loss " ++ show (loss y_pred0 tgt) | |
putStrLn $ "Loss after training " ++ show (loss y_pred tgt) | |
putStrLn "Some predictions by an untrained network:" | |
print $ takeRows 5 y_pred0 | |
putStrLn "Some predictions by a trained network:" | |
print $ takeRows 5 y_pred | |
putStrLn "Targets" | |
print $ takeRows 5 tgt |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment