Skip to content

Instantly share code, notes, and snippets.

Created February 20, 2016 03:09
Show Gist options
  • Save jeffheaton/e6ce48b06da3ea2ef507 to your computer and use it in GitHub Desktop.
Save jeffheaton/e6ce48b06da3ea2ef507 to your computer and use it in GitHub Desktop.
open DiffSharp.AD.Float64
open DiffSharp.Util
// A single layer of neurons.
// W and b are mutable because training replaces them in place.
type Layer =
    {
        // Weight matrix
        mutable W : DM
        // Bias vector
        mutable b : DV
        // Activation function
        a : DV -> DV
    }
// A feedforward network: an ordered array of layers.
type Network =
    {
        // The layers forming this network
        layers : Layer[]
    }
// Evaluate one layer on the input vector x:
// the layer's activation applied to W * x + b.
let runLayer (x:DV) (l:Layer) =
    let preActivation = l.W * x + l.b
    l.a preActivation
// Evaluate the whole network on the input vector x by feeding the output
// of each layer into the next one, in array order.
let runNetwork (x:DV) (n:Network) =
    n.layers |> Array.fold runLayer x
// Shared random number generator used for weight and bias initialization
let rnd = new System.Random()
// Build a fully connected feedforward network with sigmoid activations.
// l : number of inputs, followed by the number of neurons in each subsequent layer
// Every weight and bias is drawn as -0.5 + rnd.NextDouble().
let createNetwork (l:int[]) =
    // One random draw, shifted by -0.5
    let sample () = -0.5 + rnd.NextDouble()
    // Layer i maps l.[i] inputs to l.[i + 1] outputs
    let makeLayer i =
        let inputs = l.[i]
        let outputs = l.[i + 1]
        {W = DM.init outputs inputs (fun _ _ -> sample ())
         b = DV.init outputs (fun _ -> sample ())
         a = sigmoid}
    {layers = Array.init (l.Length - 1) makeLayer}
// The backpropagation algorithm: gradient descent on the summed squared error
// n: network to be trained (its layers' W and b are overwritten in place)
// eta: learning rate
// epochs: number of training epochs
// x: training input vectors
// y: training target vectors (paired element-wise with x)
// Returns a sequence yielding the loss of each epoch. The sequence is lazy:
// training only runs while it is enumerated, and enumerating it again
// continues training the same network.
let backprop (n:Network) eta epochs (x:DV[]) (y:DV[]) =
    let tag = DiffSharp.Util.GlobalTagger.Next
    seq {
        // F# ranges are inclusive at both ends, so 0 .. epochs - 1 gives
        // exactly `epochs` iterations (0 .. epochs would run one extra epoch)
        for j in 0 .. epochs - 1 do
            // Re-wrap the current parameters as reverse-mode variables for this epoch
            for l in n.layers do
                l.W <- l.W |> makeReverse tag
                l.b <- l.b |> makeReverse tag
            // Sum of squared errors over all training pairs
            let loss =
                Array.map2 (fun input target -> DV.l2normSq (target - runNetwork input n)) x y
                |> Array.sum
            loss |> reverseProp (D 1.) // Propagate adjoint value 1 backward
            // Gradient-descent step: primal value minus eta times the adjoint
            for l in n.layers do
                l.W <- primal (l.W.P - eta * l.W.A)
                l.b <- primal (l.b.P - eta * l.b.A)
            printfn "Iteration %i, loss %f" j (float loss)
            yield float loss
    }
open FSharp.Charting
// Training inputs for the logical OR function: all four 2-bit combinations
let ORx = [| for a, b in [0., 0.; 0., 1.; 1., 0.; 1., 1.] -> toDV [a; b] |]
// Matching OR targets, in the same order as ORx
let ORy = [| for t in [0.; 1.; 1.; 1.] -> toDV [t] |]

// 2 inputs, one layer with one neuron
let net2 = [|2; 1|] |> createNetwork

// Train: learning rate 0.9, epochs argument 1000
let train2 = (ORx, ORy) ||> backprop net2 0.9 1000

// Plot the error during training
Chart.Line train2
// Sample F# Interactive output:
// val net2 : Network = {layers = [|{W = DM [[0.230677625; 0.1414874814]];
//                                   b = DV [|0.4233988253|];}|];}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment