Skip to content

Instantly share code, notes, and snippets.

@BartoszMilewski
Last active November 5, 2023 21:21
Show Gist options
  • Save BartoszMilewski/2e2f60dd4abfdda86dc678d1ee85f487 to your computer and use it in GitHub Desktop.
Save BartoszMilewski/2e2f60dd4abfdda86dc678d1ee85f487 to your computer and use it in GitHub Desktop.
Neural Networks in Idris
module HVect
import Data.Vect
-- Heterogeneous vector: a sequence whose element types are listed,
-- position by position, in the indexing `Vect n Type`.
public export
data HVect : Vect n Type -> Type where
  -- The empty heterogeneous vector (indexed by the empty type vector).
  Nil  : HVect []
  -- Prepend a value of type `h` to a vector whose element types are `t`.
  (::) : h -> HVect t -> HVect (h :: t)
-- Rendering of the empty heterogeneous vector: just a newline, which
-- terminates the " :: "-separated rendering of a non-empty vector.
export
Show (HVect []) where
  show [] = "\n"

-- Rendering of a non-empty heterogeneous vector: the head, the separator
-- " :: ", then the rendering of the tail.
export
(Show t, Show (HVect ts)) => Show (HVect (t :: ts)) where
  show (hd :: tl) = show hd ++ " :: " ++ show tl
-- Combining two empty heterogeneous vectors yields the empty vector.
export
Semigroup (HVect []) where
  Nil <+> Nil = Nil

-- Non-empty heterogeneous vectors are combined position by position,
-- using each element type's own semigroup operation.
export
(Semigroup t, Semigroup (HVect ts)) => Semigroup (HVect (t :: ts)) where
  (x :: xs) <+> (y :: ys) = (x <+> y) :: (xs <+> ys)
-- The neutral element of the empty heterogeneous vector is the empty vector.
export
Monoid (HVect []) where
  neutral = Nil

-- The neutral element of a non-empty heterogeneous vector holds the neutral
-- element of every element type, position by position.
export
(Monoid t, Monoid (HVect ts)) => Monoid (HVect (t :: ts)) where
  neutral = neutral :: neutral
-- A monoid whose elements can additionally be multiplied by a Double scalar.
public export
interface (Semigroup v, Monoid v) => VSpace v where
  -- Multiply a vector by a scalar.
  scale : Double -> v -> v

-- Scaling the empty heterogeneous vector leaves it empty.
export
VSpace (HVect []) where
  scale _ Nil = Nil

-- Scaling a non-empty heterogeneous vector scales every component with that
-- component's own VSpace implementation. This is the inductive counterpart
-- of the instance above, mirroring the Show / Semigroup / Monoid instance
-- pairs; without it `scale` is only usable on the empty vector.
export
(VSpace t, VSpace (HVect ts)) => VSpace (HVect (t :: ts)) where
  scale a (x :: xs) = scale a x :: scale a xs
-- Wrap every type of a type vector in `Vect k`,
-- i.e. the type-level equivalent of `map (Vect k) ts`.
public export
ReplTypes : (k : Nat) -> (ts : Vect l Type) -> Vect l Type
ReplTypes _ []          = []
ReplTypes k (ty :: tys) = Vect k ty :: ReplTypes k tys
-- Collapse a heterogeneous vector of length-k columns into a heterogeneous
-- vector of single values by `concat`-ing each column.
-- `isMono` supplies one Monoid dictionary per element type of `ts`; it is
-- peeled off in step with the data so the recursive call receives the
-- dictionaries for the remaining types.
-- NOTE(review): the head dictionary is matched with `_`; presumably the
-- `concat` on the head column is resolved from it -- confirm.
export
concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} ->
          HVect (ReplTypes k ts) -> HVect ts
concatH {l = 0} {ts = []} [] = []
concatH {ts = ty :: tys} {isMono = (_ :: dicts)} (col :: cols) =
  concat col :: concatH {isMono = dicts} cols
-- Build the heterogeneous vector whose every component is an empty `Vect`,
-- one component per type in `ts`.
export
emptyVTypes : {l : Nat} -> (ts : Vect l Type) -> HVect (ReplTypes 0 ts)
emptyVTypes []          = []
emptyVTypes (_ :: rest) = Nil :: emptyVTypes rest
-- Generalization of `zipWith (::)` to heterogeneous vectors: each value of
-- the first argument is consed onto the column at the same position of the
-- second, growing every column from length k to length S k.
export
zipCons : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} ->
          HVect ts -> HVect (ReplTypes k ts) -> HVect (ReplTypes (S k) ts)
zipCons Nil Nil = Nil
zipCons (x :: xs) (col :: cols) = (x :: col) :: zipCons xs cols
-- Transpose k heterogeneous rows into one heterogeneous vector of length-k
-- columns. The empty input produces an empty column for every type in `ts`;
-- otherwise the first row is consed onto the transposed remainder.
export
transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} ->
             Vect k (HVect ts) -> HVect (ReplTypes k ts)
transposeH {k = 0} {ts} Nil = emptyVTypes ts
transposeH (row :: rows)    = zipCons row (transposeH rows)
module NNet
import Data.Vect
import HVect
-- `V n` is shorthand for a vector of exactly n Doubles.
-- It is used throughout this module for neuron inputs, outputs and weights.
V : (n : Nat) -> Type
V n = Vect n Double
-- Parameters of an affine lens with m inputs and a single output.
record Para (m : Nat) where
  constructor MkPara
  -- one weight per input
  weight : V m
  -- scalar offset
  bias   : Double
-- A block of n parameter records, each m wide
-- (n * m weights plus n biases; the parameters of n neurons with m inputs each).
ParaBlock : (m : Nat) -> (n : Nat) -> Type
ParaBlock m n = Vect n (Para m)
-- Chain of parameter-block types for a multi-layer perceptron with m inputs.
-- For widths n1 :: n2 :: n3 :: ... it yields
--   ParaBlock m n1 :: ParaBlock n1 n2 :: ParaBlock n2 n3 :: ...
-- i.e. each block's input width is the previous block's output width.
ParaChain : (m : Nat) -> (ns : Vect l Nat) -> Vect l Type
ParaChain _ []          = []
ParaChain m (n :: rest) = ParaBlock m n :: ParaChain n rest
-- Chain of length-k vectors of parameter blocks (for batches of perceptrons):
-- every type of `ParaChain m ns` wrapped in `Vect k` via `ReplTypes`.
VParaChain : (k : Nat) -> (m : Nat) -> (ns : Vect l Nat) -> Vect l Type
VParaChain k m ns = ReplTypes k (ParaChain m ns)
-------------
-- Interfaces
-------------
-- Textual form of a parameter record: weights, then bias, then a newline.
{m : Nat} -> Show (Para m) where
  show (MkPara w b) = "weight: " ++ show w ++ " bias: " ++ show b ++ "\n"
-- Semigroup
-- Doubles are combined by addition.
Semigroup Double where
  (<+>) = (+)

-- Parameter records are combined by adding weights pointwise and adding biases.
Semigroup (Para m) where
  p <+> q = MkPara (zipWith (+) p.weight q.weight) (p.bias + q.bias)
-- Monoid
-- The neutral Double under addition is zero.
Monoid Double where
  neutral = 0.0

-- The neutral parameter record has every weight and the bias equal to the
-- neutral Double (0.0). The width m must be available at run time to build
-- the weight vector.
{m : Nat} -> Monoid (Para m) where
  neutral = MkPara (replicate m neutral) neutral
-- Evidence that every type in `ParaChain m ns` is a Monoid: one Monoid
-- dictionary per parameter block, collected in a heterogeneous vector.
isMonoChain : {m : Nat} -> (ns : Vect l Nat) -> HVect (map Monoid (ParaChain m ns))
isMonoChain []              = []
isMonoChain (width :: rest) = MkMonoid neutral :: isMonoChain rest
-- Vector Space
-- Scaling a parameter record multiplies every weight and the bias by the factor.
{m : Nat} -> VSpace (Para m) where
  scale factor p = MkPara (map (factor *) p.weight) (factor * p.bias)

-- Scaling a parameter block scales each parameter record in it.
{m : Nat} -> {n : Nat} -> VSpace (ParaBlock m n) where
  scale factor block = scale factor <$> block
-- Collapse k stacked copies of every parameter block into a single block per
-- layer. All types in `ParaChain m ns` are Monoids; the required evidence is
-- produced by `isMonoChain` and handed to `concatH`.
collectH : {k : Nat} -> {m : Nat} -> {l : Nat} -> {ns : Vect l Nat} ->
           HVect (VParaChain k m ns) -> HVect (ParaChain m ns)
collectH stacked = concatH {isMono = isMonoChain ns} stacked
-- Combine k parameter chains into one: transpose the k chains into per-layer
-- columns, then collapse every column with `collectH`.
collectParas : {k : Nat} -> {m : Nat} -> {l : Nat} -> {ns : Vect l Nat} ->
               Vect k (HVect (ParaChain m ns)) -> HVect (ParaChain m ns)
collectParas chains = collectH (transposeH chains)
----------
-- Parametric lens
-- The fully general form would be
--   record PLens p p' s s' a a'
--     fwd : (p, s) -> a
--     bwd : (p, s, a') -> (p', s')
-- Here the forward and backward types coincide (p' = p, s' = s, a' = a).
record PLens p s a where
  constructor MkPLens
  -- forward pass: parameters and input to output
  fwd : (p, s) -> a
  -- backward pass: parameters, input and a value of the output type
  -- to values of the parameter and input types
  bwd : (p, s, a) -> (p, s)
-- Special case of a parametric lens with p = ().
-- Having a dedicated record simplifies composition.
record Lens s a where
  constructor MkLens
  -- forward pass: input to output
  fwd0 : s -> a
  -- backward pass: input and a value of the output type to a value of the input type
  bwd0 : (s, a) -> s
-- Composition of parametric lenses.
--   lensA.fwd : (p, s) -> a         lensA.bwd : (p, s, a) -> (p, s)
--   lensB.fwd : (q, a) -> b         lensB.bwd : (q, a, b) -> (q, a)
-- The forward pass runs lensA then lensB. The backward pass re-runs
-- lensA.fwd to recover the intermediate value, pushes the output value back
-- through lensB, then through lensA. Parameters are paired.
compose : PLens p s a -> PLens q a b ->
          PLens (p, q) s b
compose lensA lensB = MkPLens forward backward
  where
    forward : ((p, q), s) -> b
    forward ((parA, parB), inp) = lensB.fwd (parB, lensA.fwd (parA, inp))

    backward : ((p, q), s, b) -> ((p, q), s)
    backward ((parA, parB), inp, grad) =
      let mid            = lensA.fwd (parA, inp)
          (dparB, dmid)  = lensB.bwd (parB, mid, grad)
          (dparA, dinp)  = lensA.bwd (parA, inp, dmid)
      in ((dparA, dparB), dinp)
-- Helpers for composing a parametric lens with a non-parametric one.
-- composeR: parametric lens first, plain lens second. The backward pass
-- re-runs lensA.fwd to obtain the intermediate value required by lensB.bwd0.
composeR : PLens p s a -> Lens a b ->
           PLens p s b
composeR lensA lensB = MkPLens forward backward
  where
    forward : (p, s) -> b
    forward (par, inp) = lensB.fwd0 (lensA.fwd (par, inp))

    backward : (p, s, b) -> (p, s)
    backward (par, inp, grad) =
      let mid  = lensA.fwd (par, inp)
          dmid = lensB.bwd0 (mid, grad)
      in lensA.bwd (par, inp, dmid)
-- composeL: plain lens first, parametric lens second.
--   lensA.fwd0 : s -> a             lensA.bwd0 : (s, a) -> s
--   lensB.fwd  : (p, a) -> b        lensB.bwd  : (p, a, b) -> (p, a)
composeL : Lens s a -> PLens p a b ->
           PLens p s b
composeL lensA lensB = MkPLens forward backward
  where
    forward : (p, s) -> b
    forward (par, inp) = lensB.fwd (par, lensA.fwd0 inp)

    backward : (p, s, b) -> (p, s)
    backward (par, inp, grad) =
      let (dpar, dmid) = lensB.bwd (par, lensA.fwd0 inp, grad)
      in (dpar, lensA.bwd0 (inp, dmid))
-- Product of parametric lenses: the two lenses run side by side on paired
-- parameters, paired inputs and paired outputs, without interacting.
prodLens :
  PLens p s a ->
  PLens p' s' a' ->
  PLens (p, p') (s, s') (a, a')
prodLens lensA lensB = MkPLens forward backward
  where
    forward : ((p, p'), (s, s')) -> (a, a')
    forward ((parA, parB), (inA, inB)) =
      (lensA.fwd (parA, inA), lensB.fwd (parB, inB))

    backward : ((p, p'), (s, s'), (a, a')) -> ((p, p'), (s, s'))
    backward ((parA, parB), (inA, inB), (gradA, gradB)) =
      let (dparA, dinA) = lensA.bwd (parA, inA, gradA)
          (dparB, dinB) = lensB.bwd (parB, inB, gradB)
      in ((dparA, dparB), (dinA, dinB))
-- Run n copies of a lens in parallel: copy i gets the i-th parameter and the
-- i-th input and produces the i-th output. For n = 0 both passes map empty
-- vectors to empty vectors.
vecLens : (n : Nat) -> PLens p s a -> PLens (Vect n p) (Vect n s) (Vect n a)
vecLens Z _ =
  MkPLens (\(Nil, Nil) => Nil)
          (\(Nil, Nil, Nil) => (Nil, Nil))
vecLens (S k) lns = MkPLens forward backward
  where
    -- the lens handling the remaining k positions
    rest : PLens (Vect k p) (Vect k s) (Vect k a)
    rest = vecLens k lns

    forward : (Vect (S k) p, Vect (S k) s) -> Vect (S k) a
    forward (par :: pars, inp :: inps) =
      lns.fwd (par, inp) :: rest.fwd (pars, inps)

    backward : (Vect (S k) p, Vect (S k) s, Vect (S k) a) -> (Vect (S k) p, Vect (S k) s)
    backward (par :: pars, inp :: inps, grad :: grads) =
      let (dpar, dinp)   = lns.bwd (par, inp, grad)
          (dpars, dinps) = rest.bwd (pars, inps, grads)
      in (dpar :: dpars, dinp :: dinps)
-- A branching combinator: the forward pass copies the input n times; the
-- backward pass ignores the original input and merges the n returned values
-- with the Monoid on s (`concat`, i.e. repeated <+>).
branch : Monoid s => (n : Nat) -> Lens s (Vect n s)
branch n = MkLens fanOut fanIn
  where
    fanOut : s -> Vect n s
    fanOut x = replicate n x

    fanIn : (s, Vect n s) -> s
    fanIn (_, copies) = concat copies
-- Batch n lenses in parallel, all sharing the same parameters.
-- Inputs and outputs are n-vectors; in the backward pass the n returned
-- parameter values are merged into one with the Monoid on p (`concat`).
batch : Monoid p =>
        (n : Nat) ->
        PLens p s a ->
        PLens p (Vect n s) (Vect n a)
batch n lns = MkPLens forwardAll backwardAll
  where
    forwardAll : (p, Vect n s) -> Vect n a
    forwardAll (par, inputs) =
      let shared = replicate n par
      in map lns.fwd (zip shared inputs)

    backwardAll : (p, Vect n s, Vect n a) -> (p, Vect n s)
    backwardAll (par, inputs, grads) =
      let shared           = replicate n par
          results          = map lns.bwd (zip3 shared inputs grads)
          (dpars, dinputs) = unzip results
      in (concat dpars, dinputs)
-------------------------------------
------- Vector parametric lenses ----
-------------------------------------
-- Activation lens using tanh (no parameters).
-- Forward: tanh of the input.
-- Backward: the incoming value times the derivative of tanh at the input,
-- 1 - tanh(s)^2.
activ : Lens Double Double
activ = MkLens tanh backward
  where
    backward : (Double, Double) -> Double
    backward (x, grad) =
      let t = tanh x
      in grad * (1 - t * t)   -- grad * da/ds
-- Affine parametric lens (a composition of linear and bias) with m inputs
-- and one output.
--   forward : a = bias + sum_i (weight_i * s_i)
--   backward: given the incoming value a for the output, returns
--             (MkPara (a * s) a, a * weight), i.e. (da/dw, da/db) and da/ds
--             each multiplied by a.
-- The width argument is bound as `m` so that the local signatures below
-- refer to the same `m` as the top-level signature.
affine : (m : Nat) -> PLens (Para m) (V m) Double
affine m = MkPLens fwd' bwd'
  where
    fwd' : (Para m, V m) -> Double
    fwd' (p, s) = foldl (+) (bias p) (zipWith (*) (weight p) s) -- a = b + w * s

    bwd' : (Para m, V m, Double) -> (Para m, V m)
    bwd' (p, s, a) = ( MkPara (map (a *) s) a    -- (da/dw, da/db)
                     , map (a *) (weight p))     -- da/ds
-- Neuron with m inputs and one output: the affine lens followed by the
-- tanh activation lens, glued with composeR.
-- affine m  : PLens (Para m) (V m) Double
-- activ     : Lens Double Double
-- composite : PLens (Para m) (V m) Double
neuron : (m : Nat) -> PLens (Para m) (V m) Double
neuron m = composeR (affine m) activ
-- A layer of neurons: n neurons with m inputs each, all fed the same input.
--
--    1   2  ..  n      (n outputs)
--    |   |      |
--    m   m      m      (each neuron sees all m inputs)
--     \  |     /
--        m             (one shared m-wide input)
--
-- ParaBlock m n        = Vect n (Para m)
-- neuron m             : PLens (Para m) (V m) Double
-- vecLens n (neuron m) : PLens (Vect n (Para m)) (Vect n (V m)) (Vect n Double)
-- branch n             : Lens (V m) (Vect n (V m))
-- composeL             : Lens s a -> PLens p a b -> PLens p s b
--   with s = V m, a = Vect n (V m), b = Vect n Double = V n
-- Note the argument order: neuron count n first, input width m second.
layer : (n : Nat) -> (m : Nat) -> PLens (Vect n (Para m)) (V m) (V n)
layer n m = composeL (branch n) (vecLens n (neuron m))
-- m -> [m, n1] -> [n1, n2] -> ... [n l, n (l+1)]
-- Multi-layer perceptron with m inputs and l+1 layers.
-- The neuron count of each layer is given by `ls` (a non-empty vector);
-- the output width is the last entry of `ls`.
--
--    1   2  ..  n2     [n2]
--    n1  n1     n1
--    | / \ | / \ |
--    1   2  ..  n1     [n1]  <-n1- [P[m], P[m] .. P[m]]
--    m   m      m
--     \  |     /
--        m
--
-- A single width yields one layer; for two or more widths the first layer is
-- composed with the perceptron built from the remaining widths, and the
-- paired parameters of `compose` are converted to/from the HVect chain.
makeMLP : (m : Nat) -> {l : Nat} -> (ls : Vect (S l) Nat) -> -- << architecture
          PLens (HVect (ParaChain m ls)) (V m) (V (last ls))
makeMLP m (n :: []) = MkPLens forward backward
  where
    single : PLens (ParaBlock m n) (V m) (V n)
    single = layer n m

    forward : (HVect (ParaChain m [n]), V m) -> V n
    forward ([block], inp) = single.fwd (block, inp)

    backward : (HVect (ParaChain m [n]), V m, V n) -> (HVect (ParaChain m [n]), V m)
    backward ([block], inp, grad) =
      let (dblock, dinp) = single.bwd (block, inp, grad)
      in ([dblock], dinp)
makeMLP m (n1 :: n2 :: ns) = MkPLens forward backward
  where
    -- first layer composed with the perceptron for the remaining widths
    chained : PLens (ParaBlock m n1, HVect (ParaChain n1 (n2 :: ns)))
                    (V m)
                    (V (last (n2 :: ns)))
    chained = compose (layer n1 m) (makeMLP n1 (n2 :: ns)) -- <<<<

    forward : (HVect (ParaChain m (n1 :: n2 :: ns)), V m) -> V (last (n1 :: n2 :: ns))
    forward (block :: blocks, inp) = chained.fwd ((block, blocks), inp)

    backward : (HVect (ParaChain m (n1 :: n2 :: ns)), V m, V (last (n1 :: n2 :: ns))) ->
               (HVect (ParaChain m (n1 :: n2 :: ns)), V m)
    backward (block :: blocks, inp, grad) =
      let ((dblock, dblocks), dinp) = chained.bwd ((block, blocks), inp, grad)
      in (dblock :: dblocks, dinp)
-- Split a flat vector of n * m elements into n rows of m elements each.
-- Example (m = 3, n = 2):
--   xs = [1, 2, 3, 4, 5, 6]
--   vw = [[1, 2, 3], [4, 5, 6]]
reshape : (m : Nat) -> (n : Nat) -> Vect (n * m) a -> Vect n (Vect m a)
reshape _ Z _ = Nil
reshape m (S rows) flat =
  let firstRow  = take m flat
      remainder = drop m flat
  in firstRow :: reshape m rows remainder
-- A connector lens between a vector of n rows (m wide each) and a single
-- flat vector of n * m elements.
-- Forward flattens the rows; backward ignores the original rows and
-- reshapes the flat vector back into n rows of m.
bind : {m : Nat} -> {n : Nat} -> Lens (Vect n (Vect m s)) (Vect (n*m) s)
bind = MkLens flatten unflatten
  where
    flatten : Vect n (Vect m s) -> Vect (n*m) s
    flatten rows = concat rows

    unflatten : {m : Nat} -> {n : Nat} -> (Vect n (Vect m s), Vect (n*m) s) -> (Vect n (Vect m s))
    unflatten {m} {n} (_, flat) = reshape m n flat
-- Like `batch`, but the caller supplies `collectP`, the function that merges
-- the n parameter values returned by the backward pass into one, instead of
-- relying on a Monoid on p.
batchN : (n : Nat) ->
         (Vect n p -> p) ->
         PLens p s a ->
         PLens p (Vect n s) (Vect n a)
batchN n collectP lns = MkPLens forwardAll backwardAll
  where
    forwardAll : (p, Vect n s) -> Vect n a
    forwardAll (par, inputs) =
      let shared = replicate n par
      in map lns.fwd (zip shared inputs)

    backwardAll : (p, Vect n s, Vect n a) -> (p, Vect n s)
    backwardAll (par, inputs, grads) =
      let shared           = replicate n par
          results          = map lns.bwd (zip3 shared inputs grads)
          (dpars, dinputs) = unzip results
      in (collectP dpars, dinputs)
-- Squared-error loss: 0.5 * Sum (si - gi)^2
-- (half the sum of squared differences between s and g).
-- Its derivative with respect to si is (si - gi).
delta : V n -> V n -> Double
delta s g = 0.5 * sum (map square (zipWith (-) s g))
  where
    square : Double -> Double
    square x = x * x
-- Loss lens against a fixed ground truth.
-- Forward: `delta` between the input and the ground truth.
-- Backward: the incoming scalar times (input - ground truth), component-wise.
loss : V n -> Lens (V n) Double
loss gtruth = MkLens forward backward
  where
    forward : V n -> Double
    forward guess = delta guess gtruth

    backward : (V n, Double) -> V n
    backward (guess, grad) = map (grad *) (zipWith (-) guess gtruth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment