Last active
November 5, 2023 21:21
-
-
Save BartoszMilewski/2e2f60dd4abfdda86dc678d1ee85f487 to your computer and use it in GitHub Desktop.
Neural Networks in Idris
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module HVect | |
import Data.Vect | |
-- Heterogeneous vector: a sequence whose element types are listed,
-- position by position, in the `Vect n Type` index.
public export
data HVect : Vect n Type -> Type where
  -- The empty vector, indexed by the empty vector of types
  Nil : HVect Nil
  -- Prepend a value of type `h` to a vector whose element types are `t`
  (::) : h -> HVect t -> HVect (h :: t)
-- Show for heterogeneous vectors, defined by induction on the type index.
-- Base case: the empty vector renders as a single newline.
export
Show (HVect []) where
  show [] = "\n"

-- Inductive case: render the head, then the separator " :: ", then the tail.
-- Requires Show for the head type and for the tail vector.
export
(Show t, Show (HVect ts)) => Show (HVect (t :: ts)) where
  show (hd :: tl) = show hd ++ (" :: " ++ show tl)
-- Semigroup for heterogeneous vectors, by induction on the type index.
-- Base case: two empty vectors combine to the empty vector.
export
Semigroup (HVect []) where
  (<+>) Nil Nil = Nil

-- Inductive case: combine heads with the head type's (<+>) and
-- tails with the tail vector's (<+>).
export
(Semigroup t, Semigroup (HVect ts)) => Semigroup (HVect (t :: ts)) where
  (<+>) (x :: xs) (y :: ys) = (x <+> y) :: (xs <+> ys)
-- Monoid for heterogeneous vectors, by induction on the type index.
-- Base case: the neutral element of the empty vector is the empty vector.
export
Monoid (HVect []) where
  neutral = Nil

-- Inductive case: the neutral element is built from the neutral element
-- of the head type followed by the neutral element of the tail vector.
export
(Monoid t, Monoid (HVect ts)) => Monoid (HVect (t :: ts)) where
  neutral = (::) neutral neutral
-- Interface for types that are monoids and can additionally be
-- rescaled by a Double factor.
public export
interface (Semigroup v, Monoid v) => VSpace v where
  -- Scale a value by the given factor
  scale : Double -> v -> v
-- Scaling the empty heterogeneous vector yields the empty vector,
-- whatever the factor.
export
VSpace (HVect []) where
  scale _ [] = []
-- Wrap every type of a type vector in `Vect k`,
-- i.e. the type-level equivalent of `map (Vect k) ts`.
public export
ReplTypes : (k : Nat) -> (ts : Vect l Type) -> Vect l Type
ReplTypes _ Nil = Nil
ReplTypes k (ty :: rest) = (Vect k ty) :: (ReplTypes k rest)
-- Collapse a heterogeneous vector of columns into a heterogeneous vector
-- of single values: every entry `Vect k t` is reduced to one `t` with `concat`.
-- `isMono` holds one Monoid implementation per element type; its tail is
-- handed on to the recursive call.
export
concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} ->
          HVect (ReplTypes k ts) -> HVect ts
concatH {l = 0} {ts = []} Nil = Nil
concatH {ts = ty :: tys} {isMono = (_ :: proofs)} (col :: cols) =
  concat col :: concatH {isMono = proofs} cols
-- Build the heterogeneous vector that holds one empty `Vect 0 t`
-- for every type `t` in `ts`.
export
emptyVTypes : {l : Nat} -> (ts : Vect l Type) -> HVect (ReplTypes 0 ts)
emptyVTypes Nil = Nil
emptyVTypes (_ :: rest) = Nil :: emptyVTypes rest
-- Position-wise cons (a generalization of `zipWith (::)`):
-- the i-th element of the first argument is pushed onto the front of the
-- i-th column of the second, growing every column from length k to S k.
export
zipCons : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} ->
          HVect ts -> HVect (ReplTypes k ts) -> HVect (ReplTypes (S k) ts)
zipCons Nil Nil = Nil
zipCons (x :: xs) (col :: cols) = (x :: col) :: zipCons xs cols
-- Transpose k heterogeneous rows into one heterogeneous vector of columns.
-- With no rows, every column is empty; otherwise the first row is consed
-- onto the transposed remainder, position by position.
export
transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} ->
             Vect k (HVect ts) -> HVect (ReplTypes k ts)
transposeH {k = 0} {ts} Nil = emptyVTypes ts
transposeH (row :: rows) = zipCons row (transposeH rows)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module NNet | |
import Data.Vect | |
import HVect | |
-- A vector of `len` Doubles
V : (n : Nat) -> Type
V len = Vect len Double
-- Parameters of an affine map with m inputs and one output:
-- one weight per input plus a single bias.
record Para (m : Nat) where
  constructor MkPara
  weight : V m
  bias   : Double
-- A block of `outs` parameter records, each `ins` weights wide
-- (one `Para ins` per output).
ParaBlock : (m : Nat) -> (n : Nat) -> Type
ParaBlock ins outs = Vect outs (Para ins)
-- Types of the parameter blocks of a multi-layer perceptron with m inputs.
-- For widths n1, n2, n3, ... the result is
--   ParaBlock m n1 :: ParaBlock n1 n2 :: ParaBlock n2 n3 :: ...
-- (the output width of one block is the input width of the next).
ParaChain : (m : Nat) -> (ns : Vect l Nat) -> Vect l Type
ParaChain _ Nil = Nil
ParaChain m (n :: rest) = ParaBlock m n :: ParaChain n rest
-- A parameter chain in which every block type is replicated `batch` times
-- (each entry becomes a `Vect batch` of the corresponding block type).
VParaChain : (k : Nat) -> (m : Nat) -> (ns : Vect l Nat) -> Vect l Type
VParaChain batch ins ns = ReplTypes batch (ParaChain ins ns)
------------- | |
-- Interfaces | |
------------- | |
-- Render a parameter record as its weights followed by its bias,
-- terminated by a newline.
{m : Nat} -> Show (Para m) where
  show (MkPara w b) = "weight: " ++ show w ++ " bias: " ++ show b ++ "\n"
-- Semigroup: Doubles combine by addition
Semigroup Double where
  (<+>) x y = x + y

-- Semigroup: parameters combine component-wise
-- (weights added element by element, biases added)
Semigroup (Para m) where
  lhs <+> rhs = MkPara (zipWith (+) (weight lhs) (weight rhs)) (bias lhs + bias rhs)

-- Monoid: the neutral Double is zero
Monoid Double where
  neutral = 0.0

-- Monoid: the neutral parameter has m zero weights and a zero bias
{m : Nat} -> Monoid (Para m) where
  neutral = MkPara (replicate m 0.0) 0.0
-- Evidence that every type in `ParaChain m ns` is a Monoid:
-- one Monoid implementation per block, built from that block's `neutral`.
isMonoChain : {m : Nat} -> (ns : Vect l Nat) -> HVect (map Monoid (ParaChain m ns))
isMonoChain [] = []
isMonoChain (width :: widths) = MkMonoid neutral :: isMonoChain widths
-- Vector Space
-- Scaling a parameter record multiplies every weight and the bias by the factor.
{m : Nat} -> VSpace (Para m) where
  scale factor para =
    MkPara (map (\w => factor * w) (weight para)) (factor * bias para)

-- Scaling a parameter block scales each of its parameter records.
{m : Nat} -> {n : Nat} -> VSpace (ParaBlock m n) where
  scale factor block = scale factor <$> block
-- Collapse a batch-replicated parameter chain into a single parameter chain.
-- All types in `ParaChain m ns` are Monoids; `isMonoChain ns` supplies that
-- evidence to `concatH`.
collectH : {k : Nat} -> {m : Nat} -> {l : Nat} -> {ns : Vect l Nat} ->
           HVect (VParaChain k m ns) -> HVect (ParaChain m ns)
collectH batched = concatH {isMono = isMonoChain ns} batched
-- Combine k parameter chains into one: transpose the k rows into
-- per-block columns, then collapse each column with `collectH`.
collectParas : {k : Nat} -> {m : Nat} -> {l : Nat} -> {ns : Vect l Nat} ->
               Vect k (HVect (ParaChain m ns)) -> HVect (ParaChain m ns)
collectParas rows = collectH (transposeH rows)
---------- | |
-- Parametric lens.
-- This is the simplified form of the general `PLens p p' s s' a a'`, in which
-- the values flowing backwards have the same types as those flowing forwards.
--   fwd : from a parameter and an input, produce an output
--   bwd : from a parameter, an input and a value of the output type,
--         produce a value of the parameter type and one of the input type
record PLens p s a where
  constructor MkPLens
  fwd : (p, s) -> a
  bwd : (p, s, a) -> (p, s)
-- Non-parametric lens: the special case of a parametric lens with p = ().
-- Having it as a separate record simplifies composition.
--   fwd0 : from an input, produce an output
--   bwd0 : from an input and a value of the output type,
--          produce a value of the input type
record Lens s a where
  constructor MkLens
  fwd0 : s -> a
  bwd0 : (s, a) -> s
-- Sequential composition of two parametric lenses.
-- The composite is parameterised by the pair of both parameters.
-- Forward: run `inner`, feed its output to `outer`.
-- Backward: recompute `inner`'s forward output, run `outer` backwards on it,
-- then push the resulting value back through `inner`.
compose : PLens p s a -> PLens q a b ->
          PLens (p, q) s b
compose inner outer = MkPLens forward backward
  where
    forward : ((p, q), s) -> b
    forward ((pp, qq), x) = outer.fwd (qq, inner.fwd (pp, x))

    backward : ((p, q), s, b) -> ((p, q), s)
    backward ((pp, qq), x, db) =
      let mid      = inner.fwd (pp, x)
          (dq, da) = outer.bwd (qq, mid, db)
          (dp, dx) = inner.bwd (pp, x, da)
      in ((dp, dq), dx)
-- Helpers for composing a parametric lens with a non-parametric one.
-- composeR: the non-parametric lens is applied AFTER the parametric one.
-- Backward: the non-parametric lens is run backwards on the recomputed
-- intermediate value, and its result is pushed through the parametric lens.
composeR : PLens p s a -> Lens a b ->
           PLens p s b
composeR plens post = MkPLens forward backward
  where
    forward : (p, s) -> b
    forward input = post.fwd0 (plens.fwd input)

    backward : (p, s, b) -> (p, s)
    backward (pp, x, db) =
      plens.bwd (pp, x, post.bwd0 (plens.fwd (pp, x), db))
-- composeL: the non-parametric lens is applied BEFORE the parametric one.
--   pre.fwd0  : s -> a
--   pre.bwd0  : (s, a) -> s
--   plens.fwd : (p, a) -> b
--   plens.bwd : (p, a, b) -> (p, a)
-- Backward: run the parametric lens backwards on the pre-processed input,
-- then push the input-side result back through the non-parametric lens.
composeL : Lens s a -> PLens p a b ->
           PLens p s b
composeL pre plens = MkPLens forward backward
  where
    forward : (p, s) -> b
    forward (pp, x) = plens.fwd (pp, pre.fwd0 x)

    backward : (p, s, b) -> (p, s)
    backward (pp, x, db) =
      let (dp, da) = plens.bwd (pp, pre.fwd0 x, db)
      in (dp, pre.bwd0 (x, da))
-- Product of two parametric lenses: both run side by side on paired
-- parameters, paired inputs and paired outputs, without interacting.
prodLens :
  PLens p s a ->
  PLens p' s' a' ->
  PLens (p, p') (s, s') (a, a')
prodLens lensA lensB = MkPLens forward backward
  where
    forward : ((p, p'), (s, s')) -> (a, a')
    forward ((pa, pb), (xa, xb)) = (lensA.fwd (pa, xa), lensB.fwd (pb, xb))

    backward : ((p, p'), (s, s'), (a, a')) -> ((p, p'), (s, s'))
    backward ((pa, pb), (xa, xb), (da, db)) =
      let (dpa, dxa) = lensA.bwd (pa, xa, da)
          (dpb, dxb) = lensB.bwd (pb, xb, db)
      in ((dpa, dpb), (dxa, dxb))
-- Run n copies of a lens in parallel, each copy with its own parameter,
-- input and output (element i of every vector belongs to copy i).
-- Defined by recursion on n: zero copies map empty vectors to empty vectors;
-- S k copies handle the heads with the lens and the tails with k copies.
vecLens : (n : Nat) -> PLens p s a -> PLens (Vect n p) (Vect n s) (Vect n a)
vecLens Z _ = MkPLens (\(Nil, Nil) => Nil) (\(Nil, Nil, Nil) => (Nil, Nil))
vecLens (S k) single = MkPLens forward backward
  where
    others : PLens (Vect k p) (Vect k s) (Vect k a)
    others = vecLens k single

    forward : (Vect (S k) p, Vect (S k) s) -> Vect (S k) a
    forward (q :: qs, x :: xs) = single.fwd (q, x) :: others.fwd (qs, xs)

    backward : (Vect (S k) p, Vect (S k) s, Vect (S k) a) -> (Vect (S k) p, Vect (S k) s)
    backward (q :: qs, x :: xs, d :: ds) =
      let (dq, dx)   = single.bwd (q, x, d)
          (dqs, dxs) = others.bwd (qs, xs, ds)
      in (dq :: dqs, dx :: dxs)
-- A branching combinator.
-- Forward: copy the input n times.
-- Backward: ignore the original input and combine the n incoming values
-- with the monoid operation (`concat`).
branch : Monoid s => (n : Nat) -> Lens s (Vect n s)
branch n = MkLens fanOut fanIn
  where
    fanOut : s -> Vect n s
    fanOut x = replicate n x

    fanIn : (s, Vect n s) -> s
    fanIn (_, incoming) = concat incoming
-- Run one lens on a batch of n inputs that all share the same parameter.
-- Forward: apply the lens to every input with the shared parameter.
-- Backward: run the lens backwards on every (input, output-side value) pair;
-- the n parameter-side results are combined into one with the monoid
-- operation (`concat`), the input-side results are returned as a vector.
batch : Monoid p =>
        (n : Nat) ->
        PLens p s a ->
        PLens p (Vect n s) (Vect n a)
batch n lns = MkPLens forward backward
  where
    forward : (p, Vect n s) -> Vect n a
    forward (param, inputs) = map (\x => lns.fwd (param, x)) inputs

    backward : (p, Vect n s, Vect n a) -> (p, Vect n s)
    backward (param, inputs, outs) =
      let (dparams, dinputs) =
            unzip (zipWith (\x, d => lns.bwd (param, x, d)) inputs outs)
      in (concat dparams, dinputs)
------------------------------------- | |
------- Vector parametric lenses ---- | |
------------------------------------- | |
-- Activation lens using tanh (no parameters).
-- Forward: tanh of the input.
-- Backward: the incoming value times the derivative of tanh at the input,
-- i.e. a * (1 - tanh(s)^2).
activ : Lens Double Double
activ = MkLens tanh backward
  where
    backward : (Double, Double) -> Double
    backward (s, a) =
      let t = tanh s
      in a * (1 - t * t)
-- Affine parametric lens (a composition of linear and bias).
-- Argument: m, the number of inputs.
-- Forward:  a = bias + sum_i (weight_i * s_i).
-- Backward: given the value `a` arriving at the output, returns
--   * for the parameter: weights `a * s_i` and bias `a`   (da/dw, da/db)
--   * for the input:     `a * weight_i`                   (da/ds)
--
-- The argument is bound as `m`, exactly as in the signature, so the `m` in the
-- local signatures below is the function's own argument (the clause used to
-- bind it under the unrelated name `n`).
affine : (m : Nat) -> PLens (Para m) (V m) Double
affine m = MkPLens fwd' bwd'
  where
    fwd' : (Para m, V m) -> Double
    -- start the fold from the bias, then add every weight * input product
    fwd' (p, s) = foldl (+) (bias p) (zipWith (*) (weight p) s)

    bwd' : (Para m, V m, Double) -> (Para m, V m)
    bwd' (p, s, a) = ( MkPara (map (a *) s) a    -- (da/dw, da/db)
                     , map (a *) (weight p) )    -- da/ds
-- A neuron with m inputs and one output: the affine lens followed by
-- the activation lens.
--   affine m : PLens (Para m) (V m) Double
--   activ    : Lens Double Double
--   result   : PLens (Para m) (V m) Double
neuron : (m : Nat) -> PLens (Para m) (V m) Double
neuron m = affine m `composeR` activ
-- A layer of n neurons, each reading the same m inputs.
--
--   1     2  ..  n      n neurons, one output each
--   |     |      |
--   m     m      m      every neuron receives all m inputs
--    \   / \    /
--        m              the shared input
--
-- Building blocks:
--   branch n             : Lens (V m) (Vect n (V m))
--   neuron m             : PLens (Para m) (V m) Double
--   vecLens n (neuron m) : PLens (Vect n (Para m)) (Vect n (V m)) (V n)
--   composeL             : Lens s a -> PLens p a b -> PLens p s b
-- The parameter type is ParaBlock m n = Vect n (Para m).
layer : (n : Nat) -> (m : Nat) -> PLens (Vect n (Para m)) (V m) (V n)
layer n m = composeL fanOut neurons
  where
    fanOut : Lens (V m) (Vect n (V m))
    fanOut = branch n

    neurons : PLens (Vect n (Para m)) (Vect n (V m)) (Vect n Double)
    neurons = vecLens n (neuron m)
-- Multi-layer perceptron with m inputs.
--   m -> [m, n1] -> [n1, n2] -> ... -> [n l, n (l+1)]
-- The architecture `ls : Vect (S l) Nat` gives the neuron count of each of the
-- l+1 layers; the output width is the last entry of `ls`.
-- The parameters are a heterogeneous vector with one ParaBlock per layer.
--
--   1  2 .. n2   [n2]
--   n1 n1   n1
--   |/ \|/ \|
--   1  2 .. n1   [n1]
--   m  m    m
--    \ / \ /
--     m
makeMLP : (m : Nat) -> {l : Nat} -> (ls : Vect (S l) Nat) ->   -- << architecture
          PLens (HVect (ParaChain m ls)) (V m) (V (last ls))
-- One layer: wrap/unwrap the single parameter block around `layer`.
makeMLP m (n :: []) = MkPLens forward backward
  where
    single : PLens (ParaBlock m n) (V m) (V n)
    single = layer n m

    forward : (HVect (ParaChain m [n]), V m) -> V n
    forward ([blk], input) = single.fwd (blk, input)

    backward : (HVect (ParaChain m [n]), V m, V n) -> (HVect (ParaChain m [n]), V m)
    backward ([blk], input, out) =
      let (dblk, dinput) = single.bwd (blk, input, out)
      in ([dblk], dinput)
-- Two or more layers: compose the first layer with the perceptron built from
-- the remaining widths, converting between the heterogeneous vector of
-- blocks and the (head, tail) pair expected by `compose`.
makeMLP m (n1 :: n2 :: ns) = MkPLens forward backward
  where
    chained : PLens (ParaBlock m n1, HVect (ParaChain n1 (n2 :: ns)))
                    (V m)
                    (V (last (n2 :: ns)))
    chained = compose (layer n1 m) (makeMLP n1 (n2 :: ns))

    forward : (HVect (ParaChain m (n1 :: n2 :: ns)), V m) -> V (last (n1 :: n2 :: ns))
    forward (blk :: blks, input) = chained.fwd ((blk, blks), input)

    backward : (HVect (ParaChain m (n1 :: n2 :: ns)), V m, V (last (n1 :: n2 :: ns))) ->
               (HVect (ParaChain m (n1 :: n2 :: ns)), V m)
    backward (blk :: blks, input, out) =
      let ((dblk, dblks), dinput) = chained.bwd ((blk, blks), input, out)
      in (dblk :: dblks, dinput)
-- Cut a flat vector of n * m elements into n rows of m elements each.
-- Example (m = 3, n = 2):
--   xs     = [1, 2, 3, 4, 5, 6]
--   result = [[1, 2, 3], [4, 5, 6]]
reshape : (m : Nat) -> (n : Nat) -> Vect (n * m) a -> Vect n (Vect m a)
reshape _ Z _ = []
reshape m (S k) xs =
  -- the first m elements form one row; the remainder is reshaped recursively
  let (row, rest) = splitAt m xs
  in row :: reshape m k rest
-- A connector lens between a vector of n vectors of length m and a single
-- flat vector of length n*m.
-- Forward: concatenate the rows.
-- Backward: ignore the original rows and cut the flat vector back into
-- n rows of length m with `reshape`.
bind : {m : Nat} -> {n : Nat} -> Lens (Vect n (Vect m s)) (Vect (n*m) s)
bind = MkLens flatten unflatten
  where
    flatten : Vect n (Vect m s) -> Vect (n*m) s
    flatten rows = concat rows

    unflatten : {cols : Nat} -> {rows : Nat} ->
                (Vect rows (Vect cols s), Vect (rows*cols) s) -> Vect rows (Vect cols s)
    unflatten {cols} {rows} (_, flat) = reshape cols rows flat
-- Run one lens on a batch of n inputs that all share the same parameter,
-- with a caller-supplied function for combining the parameter-side results.
-- Forward: apply the lens to every input with the shared parameter.
-- Backward: run the lens backwards on every (input, output-side value) pair;
-- the n parameter-side results are reduced to one with `collectP`, the
-- input-side results are returned as a vector.
batchN : (n : Nat) ->
         (Vect n p -> p) ->
         PLens p s a ->
         PLens p (Vect n s) (Vect n a)
batchN n collectP lns = MkPLens forward backward
  where
    forward : (p, Vect n s) -> Vect n a
    forward (param, inputs) = map (\x => lns.fwd (param, x)) inputs

    backward : (p, Vect n s, Vect n a) -> (p, Vect n s)
    backward (param, inputs, outs) =
      let (dparams, dinputs) =
            unzip (zipWith (\x, d => lns.bwd (param, x, d)) inputs outs)
      in (collectP dparams, dinputs)
-- Squared-error loss: 0.5 * Sum (si - gi)^2
-- (its derivative with respect to si is (si - gi)).
delta : V n -> V n -> Double
delta s g = 0.5 * sum (zipWith (\si, gi => (si - gi) * (si - gi)) s g)
-- Loss lens for a fixed ground truth.
-- Forward: the squared-error loss `delta` between the input and the ground truth.
-- Backward: the incoming value `a` times the element-wise difference
-- (input - ground truth).
loss : V n -> Lens (V n) Double
loss gtruth = MkLens forward backward
  where
    forward : V n -> Double
    forward s = delta s gtruth

    backward : (V n, Double) -> V n
    backward (s, a) = zipWith (\si, gi => a * (si - gi)) s gtruth
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment