Skip to content

Instantly share code, notes, and snippets.

@mrkgnao
Last active July 30, 2022 15:00
Show Gist options
  • Star 13 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mrkgnao/a45059869590d59f05100f4120595623 to your computer and use it in GitHub Desktop.
Save mrkgnao/a45059869590d59f05100f4120595623 to your computer and use it in GitHub Desktop.
A quick Idris implementation of @mstksg's "dependent Haskell" neural networks
module Main
import Data.Vect
-- %hide transpose
-- Dot product of two equal-length vectors: the element-wise products are
-- summed with a right fold that starts from 0.
dot : Num a => Vect n a -> Vect n a -> a
dot xs ys = foldr (+) 0 (zipWith (*) xs ys)
-- A matrix with `rows` rows and `cols` columns, stored as a vector of row
-- vectors.
Matrix : (rows : Nat) -> (cols : Nat) -> Type -> Type
Matrix nRows nCols ty = Vect nRows (Vect nCols ty)
-- One network layer taking `i` inputs to `o` outputs over element type `a`.
-- It holds one bias per output and an `o`-by-`i` weight matrix (one row of
-- `i` weights per output).
data Layer : Nat -> Nat -> Type -> Type where
  MkLayer : (biases : Vect o a) -> (weights : Matrix o i a) -> Layer i o a
infixr 5 :>:

-- A chain of layers. The indices are, in order: the input width, the list of
-- intermediate (hidden) widths, the output width, and the element type.
data Network : Nat -> List Nat -> Nat -> Type -> Type where
  -- The last layer of the chain; it has no hidden widths after it.
  Output : Layer i o a -> Network i [] o a
  -- Put a layer producing `h` values in front of a network consuming `h`
  -- values; `h` is recorded at the head of the hidden-width list.
  (:>:)  : Layer i h a -> Network h hs o a -> Network i (h :: hs) o a
infixl 9 .*

-- Matrix-vector product: each entry of the result is the dot product of the
-- vector with the corresponding row of the matrix.
(.*) : Num a => Matrix m n a -> Vect n a -> Vect m a
rows .* v = map (\row => dot v row) rows
-- Element-wise arithmetic on vectors. An integer literal denotes the vector
-- whose every entry is that literal.
Num a => Num (Vect n a) where
  (+) xs ys = zipWith (+) xs ys
  (*) xs ys = zipWith (*) xs ys
  -- The length `n` is taken from the implicit argument so that the literal
  -- can be replicated to the right size.
  fromInteger {n} k = replicate n (fromInteger k)
-- Things that can be multiplied by a `Double` factor.
interface Scalable a where
  scale : Double -> a -> a

-- Scaling a `Double` is ordinary multiplication.
Scalable Double where
  scale factor x = factor * x

-- Scaling a vector scales every element; because the element type only needs
-- to be `Scalable`, this also covers nested vectors such as matrices.
Scalable a => Scalable (Vect n a) where
  scale factor xs = map (scale factor) xs
infixl 9 #*#

-- Matrix-matrix product. Entry (r, c) of the result is the dot product of
-- row r of the left matrix with column c of the right matrix.
(#*#) : Num a => Matrix i j a -> Matrix j k a -> Matrix i k a
lhs #*# rhs =
  -- The columns of `rhs` are the same for every row of `lhs`, so transpose
  -- once up front instead of once per row.
  let cols = transpose rhs
  in map (\row => map (dot row) cols) lhs
-- Element-wise subtraction, negation and absolute value on vectors.
Neg a => Neg (Vect n a) where
  (-) xs ys = zipWith (-) xs ys
  negate xs = map negate xs
  abs xs    = map abs xs
-- The logistic function 1 / (1 + e^(-a)) on a single `Double`.
sigmoidD : Double -> Double
sigmoidD a =
  let denominator = 1 + exp (-a)
  in 1 / denominator
-- Derivative of `sigmoidD` at `a`, computed from the function value `s` as
-- s * (1 - s).
sigmoidD' : Double -> Double
sigmoidD' a = slope (sigmoidD a)
  where
    slope : Double -> Double
    slope s = s * (1 - s)
-- Applies `sigmoidD` to every element of a vector.
sigmoid : Vect n Double -> Vect n Double
sigmoid xs = map sigmoidD xs
-- Applies `sigmoidD'` to every element of a vector.
sigmoid' : Vect n Double -> Vect n Double
sigmoid' xs = map sigmoidD' xs
-- Affine map of a layer: multiplies the weight matrix by the input vector
-- and adds the bias vector. Arguments are input, bias, weights, in that
-- order.
calc : Vect i Double
    -> Vect o Double
    -> Matrix o i Double
    -> Vect o Double
calc input bias weights =
  let weighted = weights .* input
  in weighted + bias
-- Runs one layer without an activation function: unpacks the layer's biases
-- and weights and hands them to `calc` together with the input.
runLayer : Vect i Double
        -> Layer i o Double
        -> Vect o Double
runLayer input (MkLayer b w) = calc input b w
-- Runs one layer and passes its result through the element-wise sigmoid.
runLayerS : Vect i Double
         -> Layer i o Double
         -> Vect o Double
runLayerS input layer = sigmoid (runLayer input layer)
-- Evaluates the whole network on an input: every layer, including the final
-- one, is applied with the sigmoid activation (`runLayerS`), and each
-- layer's result becomes the next layer's input.
feedForward : Vect i Double
           -> Network i hs o Double
           -> Vect o Double
feedForward input (Output layer)   = runLayerS input layer
feedForward input (layer :>: rest) = feedForward (runLayerS input layer) rest
-- Outer product of two vectors, computed as a matrix product: `vm` is turned
-- into an m-by-1 column matrix (the transpose of the one-row matrix `[vm]`)
-- and multiplied by the 1-by-n row matrix `[vn]`.
outer : Num a
     => Vect m a
     -> Vect n a
     -> Matrix m n a
outer vm vn = transpose [vm] #*# [vn]
-- Element-wise difference `target - prediction`, where the prediction is the
-- network's `feedForward` result on `input`.
predictionError : Vect i Double
               -> Vect o Double
               -> Network i hs o Double
               -> Vect o Double
predictionError input target net =
  let predicted = feedForward input net
  in target - predicted
-- One step of gradient-descent training on a single (input, target) pair.
--
-- `eta` is the step size: every bias and weight is moved by `eta` times its
-- gradient. The result is a network of the same shape with updated layers.
--
-- The work is done by the local `go`, which walks the network front to back
-- computing each layer's pre-activation, then updates layers on the way back
-- out of the recursion. Besides the updated (sub)network it returns the
-- gradient with respect to that (sub)network's input, which the caller uses
-- to build the gradient for the layer in front of it.
backprop : Double
        -> Vect i Double
        -> Vect o Double
        -> Network i hs o Double
        -> Network i hs o Double
backprop eta input target net = fst (go input target net)
  where
    -- Gradient-descent update of a single layer, shared by both cases of
    -- `go`. Given the gradient `dEdy` with respect to the layer's
    -- pre-activation and the input `x` the layer was run on, returns the
    -- updated layer and the gradient with respect to `x`. The latter is
    -- computed from the layer's original weights, not the updated ones.
    descend : Vect o Double
           -> Vect i Double
           -> Layer i o Double
           -> (Layer i o Double, Vect i Double)
    descend dEdy x (MkLayer bias weights) =
      let bias'    = bias - (eta `scale` dEdy)
          weights' = weights - (eta `scale` (outer dEdy x))
          dWs      = (transpose weights) .* dEdy
      in (MkLayer bias' weights', dWs)

    go : Vect i Double
      -> Vect o Double
      -> Network i hs o Double
      -> (Network i hs o Double, Vect i Double)
    -- Hidden layer: recurse on the rest of the network with this layer's
    -- activation, then scale the gradient that comes back by the sigmoid
    -- derivative at this layer's pre-activation.
    go x t (layer :>: rest) =
      let y             = runLayer x layer
          (rest', dWs') = go (sigmoid y) t rest
          (layer', dWs) = descend (sigmoid' y * dWs') x layer
      in (layer' :>: rest', dWs)
    -- Final layer: the gradient starts from (activation - target), scaled by
    -- the sigmoid derivative at the pre-activation.
    go x t (Output layer) =
      let y             = runLayer x layer
          (layer', dWs) = descend (sigmoid' y * (sigmoid y - t)) x layer
      in (Output layer', dWs)
-- The starting network for the demo in `main`: 2 inputs, one hidden layer of
-- width 2, and 2 outputs, with fixed biases and weights.
initialNet : Network 2 [2] 2 Double
initialNet =
      MkLayer [0.35, 0.35]
              [ [0.15, 0.20]
              , [0.25, 0.30]
              ]
  :>: Output (MkLayer [0.60, 0.60]
                      [ [0.40, 0.45]
                      , [0.50, 0.55]
                      ])
-- The single training input used by `main`.
input : Vect 2 Double
input = [0.05, 0.10]
-- The training target that `main` pairs with `input`, i.e. the value the
-- network should learn to produce. (`target` would be the more accurate
-- name; `output` is kept here.)
output : Vect 2 Double
output = [0.01, 0.99]
-- Repeatedly trains `initialNet` on the single (`input`, `output`) pair with
-- step size 0.5 and prints one line per network: the prediction error of the
-- initial network followed by that of the next 99 trained networks.
main : IO ()
main =
  let trainOnce = backprop 0.5 input output
      errorOf   = predictionError input output
      networks  = iterate trainOnce initialNet
  in putStrLn (unlines (map (\net => show (errorOf net)) (take 100 networks)))
{-
[-0.7413650695523157, 0.2170715346785375]
[-0.7184417622337655, 0.2116230796990295]
[-0.693669009336258, 0.2064910449028144]
[-0.6672418233176436, 0.2016413051023132]
[-0.6394686816095002, 0.1970437287878285]
[-0.6107602378741731, 0.1926721758900757]
[-0.5815991457088593, 0.1885044535893083]
[-0.5524951990947665, 0.1845221406513488]
[-0.5239353912308239, 0.1807102331941322]
[-0.4963403802849389, 0.1770566276112577]
[-0.4700361147284623, 0.1735515109102654]
[-0.4452436859544368, 0.1701867509351925]
[-0.4220849349263245, 0.1669553649022958]
[-0.4005982141168513, 0.1638511093542886]
[-0.3807583637710784, 0.1608681981538362]
[-0.3624963957286282, 0.1580011304078014]
[-0.3457163205239238, 0.1552445998815962]
[-0.3303081794851857, 0.152593457729155]
[-0.3161573737586014, 0.1500427059155249]
[-0.3031508774128172, 0.1475875055225604]
[-0.2911810603416536, 0.1452231900792029]
[-0.2801477961784186, 0.1429452784460385]
[-0.2699594052550016, 0.140749484669611]
[-0.2605328461773582, 0.1386317239463343]
[-0.2517934502183678, 0.1365881147671741]
[-0.2436743990158697, 0.1346149777502912]
[-0.2361160771575853, 0.1327088318188883]
[-0.2290653827858677, 0.1308663883801339]
[-0.222475046421264, 0.1290845440890528]
[-0.2163029864420817, 0.1273603726841057]
[-0.2105117156600286, 0.1256911162825736]
[-0.2050678046903731, 0.1240741764348239]
[-0.199941402554362, 0.1225071051611767]
[-0.1951058119511725, 0.1209875961337917]
[-0.1905371150740613, 0.1195134761175647]
[-0.1862138451756125, 0.1180826967465477]
[-0.1821166989543435, 0.1166933266839454]
[-0.1782282850105121, 0.1153435441924597]
[-0.1745329039580395, 0.1140316301261325]
[-0.1710163561922812, 0.1127559613435504]
[-0.1676657737459223, 0.111515004534339]
[-0.1644694730863394, 0.1103073104454442]
[-0.1614168261005721, 0.1091315084901476]
[-0.1584981468707465, 0.1079863017206288]
[-0.1557045921609243, 0.1068704621437445]
[-0.1530280738165937, 0.1057828263593242]
[-0.1504611815227601, 0.104722291500412]
[-0.147997114579025, 0.1036878114553911]
[-0.1456296215336469, 0.1026783933526751]
[-0.1433529466768198, 0.1016930942895522]
[-0.1411617825295137, 0.1007310182877622]
[-0.1390512275811833, 0.09979131345942471]
[-0.1370167486300981, 0.09887316936797841]
[-0.1350541471663034, 0.09797581456982185]
[-0.1331595293113357, 0.09709851432334826]
[-0.1313292788925082, 0.09624056845301787]
[-0.129560033284393, 0.09540130935702407]
[-0.1278486616973177, 0.0945801001479607]
[-0.12619224563339, 0.09377633291670329]
[-0.1245880612656923, 0.09298942711045954]
[-0.1230335635266569, 0.09221882801664594]
[-0.1215263717179235, 0.09146400534488552]
[-0.1200642564767705, 0.09072445190002432]
[-0.1186451279540033, 0.08999968233960876]
[-0.1172670250753927, 0.08928923200977779]
[-0.115928105773743, 0.08859265585399401]
[-0.1146266380917467, 0.08790952738946012]
[-0.113360992067203, 0.08723943774647369]
[-0.112129632322174, 0.08658199476633321]
[-0.1109311112864131, 0.08593682215373866]
[-0.1097640629930862, 0.08530355867995043]
[-0.1086271973915681, 0.08468185743323953]
[-0.1075192951280463, 0.0840713851134347]
[-0.1064392027499095, 0.08347182136760178]
[-0.1053858282945339, 0.08288285816411867]
[-0.1043581372271753, 0.0823041992026059]
[-0.1033551486963057, 0.0817355593573641]
[-0.102375932077947, 0.08117666415213609]
[-0.101419603783417, 0.08062724926417586]
[-0.1004853243074385, 0.08008706005574517]
[-0.09957229549583126, 0.07955585113129793]
[-0.09867975801401846, 0.07903338591873366]
[-0.09780698899938538, 0.07851943627321689]
[-0.09695329988213586, 0.07801378210216214]
[-0.09611803436073543, 0.0775162110100851]
[-0.09530056651932017, 0.07702651796210702]
[-0.09450029907561171, 0.07654450496498333]
[-0.09371666174891867, 0.0760699807646058]
[-0.09294910973874204, 0.07560276055899551]
[-0.09219712230534606, 0.07514266572587369]
[-0.09146020144441579, 0.07468952356395442]
[-0.09073787064860862, 0.07424316704715994]
[-0.09002967374942897, 0.07380343459101679]
[-0.08933517383341218, 0.07337016983053202]
[-0.08865395222711617, 0.0729432214088982]
[-0.08798560754587555, 0.07252244277641773]
[-0.08732975480169207, 0.07210769199907441]
[-0.08668602456601479, 0.07169883157621615]
[-0.08605406218350631, 0.07129572826684816]
[-0.08543352703320682, 0.07089825292406515]
-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment