Last active
September 24, 2015 01:42
-
-
Save gatlin/e81b8c572b2f284f1423 to your computer and use it in GitHub Desktop.
Simple neural network with backpropagation in Haskell, using Repa. Inspired by: http://iamtrask.github.io/2015/07/12/basic-python-network/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-
To run:
  1. Ensure you have the repa and repa-algorithms packages installed.
  2. ghc --make nn_repa.hs -O2 -rtsopts
  3. ./nn_repa +RTS -s
The `+RTS -s` part prints a summary of runtime statistics; linking with
`-rtsopts` is what allows that RTS flag to be passed on the command line.
-}
import Control.Monad (forM_) | |
import Data.Array.Repa as R hiding ((++)) | |
import Data.Array.Repa.Algorithms.Matrix (mmultS, transpose2S) | |
import Data.Array.Repa.Algorithms.Randomish (randomishDoubleArray) | |
import Data.IORef | |
-- | A two-dimensional Repa array whose representation is @r@.
type Matrix r = Array r DIM2

-- | A homogeneous pair; exists only to keep type signatures short.
type Two t = (t, t)
-- | Convenience wrapper for generating random-ish matrices.
--
-- Samples come from 'randomishDoubleArray' with bounds 0 and 1 and are
-- rescaled to the range -1 .. 1. The seed is fixed at 100, so two calls
-- with the same shape return the same matrix.
randomArray :: Int               -- ^ Rows
            -> Int               -- ^ Columns
            -> Matrix U Double
randomArray rows cols = computeS (R.map toSigned unitSamples)
  where
    unitSamples = randomishDoubleArray (Z :. rows :. cols) 0 1 100
    -- Affine map taking the unit interval onto [-1, 1].
    toSigned v  = 2 * v - 1
-- | Test input data: four samples (rows) of three features (columns).
-- The first two columns enumerate every pair of bits; the third column is
-- constantly 1 (presumably acting as a bias input).
x :: Matrix U Double
x = fromListUnboxed (Z :. 4 :. 3) (concat rows)
  where
    rows = [ [0, 0, 1]
           , [0, 1, 1]
           , [1, 0, 1]
           , [1, 1, 1]
           ]
-- | Expected output: a single-column matrix holding one target per sample.
y :: Matrix U Double
y = fromListUnboxed shape targets
  where
    shape   = Z :. 4 :. 1
    targets = [0, 1, 1, 0]
-- | Train the synapses (weights) of a 3-layer network.
--
-- Each iteration runs the whole input matrix forward through two
-- logistic-sigmoid layers. It then back-propagates the error and adds the
-- weight corrections directly to the weights (an implicit learning rate of 1).
--
-- Layer widths:
--
--   * input width: the number of columns of the input matrix;
--   * hidden width: 4;
--   * output width: the number of columns of the expected-output matrix.
--
-- The input and expected-output matrices are assumed to have the same
-- number of rows (one per sample).
--
-- Returns @(syn0, syn1)@: the input-to-hidden and hidden-to-output weights.
-- A non-positive iteration count returns the initial random weights.
train :: Matrix U Double              -- ^ Input matrix
      -> Matrix U Double              -- ^ Expected output matrix
      -> Int                          -- ^ Number of iterations to run
      -> IO (Two (Matrix U Double))   -- ^ weight synapses
train _in _ex n =
  -- Force both matrices before returning, so the training cost is paid
  -- inside 'train' rather than deferred to whoever first inspects the result.
  let result@(syn0, syn1) = go n initial
  in  syn0 `seq` syn1 `seq` return result
  where
    Z :. _ :. inCols  = extent _in
    Z :. _ :. outCols = extent _ex
    hiddenCols        = 4

    -- Distinct seeds: a randomish array is produced by a sequential
    -- generator, so two arrays built from one seed share a prefix and the
    -- hidden-to-output weights would just copy syn0's first row.
    initial = ( weights 100 inCols hiddenCols
              , weights 101 hiddenCols outCols )

    -- Random-ish weights rescaled from the unit interval to [-1, 1].
    weights :: Int -> Int -> Int -> Matrix U Double
    weights seed rows cols = computeS $ R.map (\v -> 2 * v - 1) $
      randomishDoubleArray (Z :. rows :. cols) 0 1 seed

    -- Element-wise logistic sigmoid 1 / (1 + e^(-v)).
    sigmoid m = R.map (\v -> 1 / (1 + exp ((-1) * v))) m

    -- Sigmoid derivative expressed via the sigmoid's own output s: s * (1 - s).
    sigmoid' s = R.zipWith (*) s (R.map (1 -) s)

    -- Strict iteration: each step's weights are fully evaluated before the
    -- next step, so no chain of thunks builds up over many iterations.
    go :: Int -> Two (Matrix U Double) -> Two (Matrix U Double)
    go k syns@(s0, s1)
      | k <= 0    = syns
      | otherwise = s0' `seq` s1' `seq` go (k - 1) (s0', s1')
      where
        l1, l2, l2_delta, l1_delta, s0', s1' :: Matrix U Double
        l1       = computeS $ sigmoid (mmultS _in s0)
        -- Materialised once: l2 is read twice below.
        l2       = computeS $ sigmoid (mmultS l1 s1)
        l2_delta = computeS $
          R.zipWith (*) (R.zipWith (-) _ex l2) (sigmoid' l2)
        -- Uses the *old* hidden-to-output weights, as back-propagation requires.
        l1_delta = computeS $
          R.zipWith (*) (mmultS l2_delta (transpose2S s1)) (sigmoid' l1)
        s1'      = computeS $ R.zipWith (+) s1 (mmultS (transpose2S l1) l2_delta)
        s0'      = computeS $ R.zipWith (+) s0 (mmultS (transpose2S _in) l1_delta)
-- | Feed an input matrix forward through a trained network.
--
-- Each layer multiplies its incoming activations by its weight matrix and
-- applies the logistic sigmoid element-wise. The hidden activations are
-- materialised before the second multiplication.
run :: Two (Matrix U Double)   -- ^ @(syn0, syn1)@ weights, e.g. from 'train'
    -> Matrix U Double         -- ^ Input matrix
    -> Matrix U Double         -- ^ Output-layer activations
run (syn0, syn1) _in = layer syn1 (layer syn0 _in)
  where
    -- One fully-connected layer: weighted sum followed by the sigmoid.
    layer :: Matrix U Double -> Matrix U Double -> Matrix U Double
    layer w a = computeS (R.map squash (mmultS a w))
    squash v  = 1 / (1 + exp ((-1) * v))
-- | Train on the sample data @x@ / @y@ for 60000 iterations, then print the
-- network's predictions for the same inputs.
main :: IO ()
main = train x y 60000 >>= report
  where
    report syns = putStrLn ("Results: " ++ show (run syns x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment