{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}

module Main where

import Data.Array.Accelerate                       as A hiding ((!!), length)
import Data.Array.Accelerate.Debug                 as A
import Data.Array.Accelerate.Numeric.LinearAlgebra as A
import Data.Array.Accelerate.LLVM.Native           as CPU
import Control.DeepSeq
import Text.Printf
import Prelude                                     as P

-- Utilities

(^+^) :: (Shape sh, P.Num (Exp c), Elt c) => Acc (Array sh c) -> Acc (Array sh c) -> Acc (Array sh c)
u ^+^ v = A.zipWith (+) u v

(^-^) :: (Shape sh, P.Num (Exp c), Elt c) => Acc (Array sh c) -> Acc (Array sh c) -> Acc (Array sh c)
u ^-^ v = A.zipWith (-) u v

(^*^) :: (Shape sh, P.Num (Exp c), Elt c) => Acc (Array sh c) -> Acc (Array sh c) -> Acc (Array sh c)
u ^*^ v = A.zipWith (*) u v

(*^) :: forall sh a. (Shape sh, Elt a, P.Num (Exp a)) => Exp a -> Acc (Array sh a) -> Acc (Array sh a)
s *^ v = A.map (\x -> x * s) v

type Activation = Exp Double -> Exp Double

sigmoid :: Activation
sigmoid = \z -> 1.0 / (1.0 + exp (-z))

sigmoid' :: Exp Double -> Exp Double
sigmoid' = \z -> sigmoid z * (1 - sigmoid z)

data BasicNetwork = BasicNetwork [Acc (Matrix Double)] [Acc (Vector Double)]
  deriving Show
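
-- Build a network from a list of layer sizes. Weights and biases are
-- initialised deterministically from the sequence [1..] rather than randomly.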
create :: [Int] -> BasicNetwork
create xs = BasicNetwork weights biases
  where
    weights :: [Acc (Matrix Double)]
    weights = do
      idx <- [1..(length xs - 1)]
      pure $ use (fromList (Z :. xs !! idx :. xs !! (idx - 1)) [1..] :: Matrix Double)

    biases :: [Acc (Vector Double)]
    biases = do
      idx <- [1..(length xs - 1)]
      pure $ use (fromList (Z :. xs !! idx) [1..] :: Vector Double)
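
-- Propagate an input vector through the network, applying the sigmoid
-- activation after each layer.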
feedforward :: BasicNetwork -> Acc (Vector Double) -> Acc (Vector Double)
feedforward (BasicNetwork ws bs) input = res
  where
    res = feedforward' ws bs input

    feedforward' :: [Acc (Matrix Double)] -> [Acc (Vector Double)] -> Acc (Vector Double) -> Acc (Vector Double)
    feedforward' []     []     a = a
    feedforward' (w:ws) (b:bs) a = feedforward' ws bs $ A.map sigmoid $ (w #> a) ^+^ b

type TrainingData = [ (Acc (Vector Double), Acc (Vector Double)) ]
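
-- Gradient descent over the whole training set: n is the number of training
-- examples (used to scale the step) and eta is the learning rate. After each
-- epoch the updated parameters are evaluated with CPU.run and re-embedded
-- with use.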
train :: BasicNetwork -> Int -> TrainingData -> Int -> Double -> BasicNetwork
train net _ _ 0 _ = net
train net n td epochs eta = train net' n td (epochs - 1) eta
  where
    BasicNetwork weights biases = net
    net' = BasicNetwork (P.map (use . CPU.run) weights') (P.map (use . CPU.run) biases')

    nablaB :: [Acc (Vector Double)]
    nablaW :: [Acc (Matrix Double)]
    (nablaB, nablaW) = descend td

    -- Sum the per-example bias and weight gradients for each layer over the
    -- training data.
    descend :: TrainingData -> ([Acc (Vector Double)], [Acc (Matrix Double)])
    descend [(x, y)]     = backprop x y net
    descend ((x, y):td') = (nablaB', nablaW')
      where
        (nablaB, nablaW)           = descend td'
        (deltaNablaB, deltaNablaW) = backprop x y net
        nablaB' = [ nb ^+^ dnb | (nb, dnb) <- P.zip nablaB deltaNablaB ]
        nablaW' = [ nw ^+^ dnw | (nw, dnw) <- P.zip nablaW deltaNablaW ]

    velocity = lift (eta / P.fromIntegral n)
    weights' = [ w ^-^ (velocity *^ wb) | (w, wb) <- P.zip weights nablaW ]
    biases'  = [ b ^-^ (velocity *^ nb) | (b, nb) <- P.zip biases nablaB ]
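
-- Backpropagation for a single training example: returns the gradient of the
-- cost with respect to each layer's biases and weights, ordered from the first
-- layer to the last.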
backprop :: Acc (Vector Double) -> Acc (Vector Double) -> BasicNetwork -> ([Acc (Vector Double)], [Acc (Matrix Double)])
backprop actual expected (BasicNetwork ws bs) = (b, w)
  where
    (b, w) = backprop' (P.tail ws) activations zs

    backprop' :: [Acc (Matrix Double)]
              -> [Acc (Vector Double)]
              -> [Acc (Vector Double)]
              -> ([Acc (Vector Double)], [Acc (Matrix Double)])
    backprop' [] [a', a] [z] = ([delta], [nw])
      where
        delta = (cost' a expected) ^*^ (A.map sigmoid' z)
        nw    = delta >< a'
    backprop' (w:ws) (a:a':as) (z:zs) = (delta':delta:xs, y:ys)
      where
        sp     = A.map sigmoid' z
        delta' = ((transpose w) #> delta) ^*^ sp
        y      = delta' >< a
        (delta:xs, ys) = backprop' ws (a':as) zs

    (activations, zs) = calcActivations actual ws bs

    calcActivations x' []     []     = ([x'], [])
    calcActivations x' (w:ws) (b:bs) = (x':as, z:zs)
      where
        (as, zs) = calcActivations x'' ws bs
        z   = (w #> x') ^+^ b
        x'' = A.map sigmoid z
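
-- Derivative of the quadratic cost (1/2) * ||actual - expected||^2 with
-- respect to the output activations.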
cost' :: Acc (Vector Double) -> Acc (Vector Double) -> Acc (Vector Double)
cost' actual expected = actual ^-^ expected
--------------------------------------------------------------
-- Main.hs
--------------------------------------------------------------
main :: IO ()
main = do
  let
    input    = [[1, 0], [0, 1], [1, 1], [0, 0]]
    expected = [[1], [1], [0], [0]]
    xorData  = [ ( use (A.fromList (Z :. 2) x :: A.Vector Double)
                 , use (A.fromList (Z :. 1) y :: A.Vector Double) )
               | (x, y) <- P.zip input expected ]

    net   = create [2, 2, 1]
    net'  = train net 4 xorData 100 2
    net'' = let BasicNetwork ws bs = net'
            in  BasicNetwork (P.map (use . CPU.run) ws) (P.map (use . CPU.run) bs)

    feedforward' = CPU.runN (feedforward net'')

    test = feedforward' (A.fromList (Z :. 100) (cycle [0, 0, 1, 0, 1, 1, 0, 1]))
    r1   = feedforward' (A.fromList (Z :. 2) [0, 0] :: A.Vector Double)
    r2   = feedforward' (A.fromList (Z :. 2) [1, 0] :: A.Vector Double)
    r3   = feedforward' (A.fromList (Z :. 2) [1, 1] :: A.Vector Double)
    r4   = feedforward' (A.fromList (Z :. 2) [0, 1] :: A.Vector Double)

  setFlag dump_phases
  putStrLn "== TRAINING ===================================================================="
  feedforward' `seq` return ()
  putStrLn "== PREDICTION =================================================================="
  print $!! test
  print $!! r1
  print $!! r2
  print $!! r3
  print $!! r4
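
-- A minimal GHCi sketch, assuming the accelerate, accelerate-llvm-native and
-- accelerate-blas packages are installed:
--
--   ghci> let net = create [2, 2, 1]
--   ghci> CPU.run $ feedforward net (use (A.fromList (Z :. 2) [1, 0] :: A.Vector Double))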