Example feed-forward network learning XOR: two inputs, one ReLU hidden layer, and a single linear output, trained with per-example gradient descent on squared error using numjs.
const nj = require('numjs')
// The activation function of choice, the rectified linear unit (ReLU): for a given
// input x it returns x if x > 0, and 0 otherwise. This is used to find the activation
// of the hidden layer nodes during forward propagation.
function relu(x) {
  return iterator(x, x => ((x > 0) * x))
}
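// For example, relu(nj.array([[-2, 0.5]])) yields [[0, 0.5]].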
// The derivative of the activation function above. This is used during backward
// propagation and gradient descent to find the updates to the weights between the
// input and hidden layer nodes.
function reluDeriv(x) {
  return iterator(x, x => ((x > 0) ? 1 : 0))
}
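// For example, reluDeriv(nj.array([[-2, 0.5]])) yields [[0, 1]].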
// A helper method to apply either of the functions above, element-wise, to a 2-D
// numjs array of arbitrary size.
function iterator(x, fn) {
  let out = x.slice().tolist()
  for (let i = 0; i < out.length; i++) {
    for (let j = 0; j < out[i].length; j++) {
      out[i][j] = fn(out[i][j])
    }
  }
  return nj.array(out)
}
// The training data for the XOR gate, e.g. XOR with inputs A=0, B=0 returns 0
const inputs = nj.array([
  [0, 0],
  [0, 1],
  [1, 0],
  [1, 1]
])
const outputs = nj.array([[0, 1, 1, 0]]).T
// The learning rate, alpha
const alpha = 0.2
// The number of hidden layer nodes in the network
const hiddenSize = 3
// The weights connecting the layers, randomly generated between -1 and +1:
// nj.random draws from [0, 1), so scaling by 2 and subtracting 1 maps to [-1, 1)
let inputHiddenWeights = nj.random([2, hiddenSize]).multiply(2).subtract(1)
let hiddenOutputWeights = nj.random([hiddenSize, 1]).multiply(2).subtract(1)
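// With hiddenSize = 3, inputHiddenWeights has shape [2, 3] and hiddenOutputWeights
// has shape [3, 1], so each 1x2 input row maps to 1x3 hidden activations and a 1x1 output.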
// The backpropagation learning algorithm, run over an arbitrary number of iterations (here, 60)
for (let iteration = 0; iteration < 60; iteration++) {
  // Network error
  let error = 0
  for (let i = 0; i < inputs.shape[0]; i++) {
    // Forward propagation: find the activation at each layer; the input layer is
    // read straight from the training data
    let inputLayer = inputs.slice([i, i + 1])
    let hiddenLayer = relu(nj.dot(inputLayer, inputHiddenWeights))
    let outputLayer = nj.dot(hiddenLayer, hiddenOutputWeights)
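    // Note the output layer is linear: no activation function is applied after the
    // final dot product, so outputLayer can take any real value.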
    // Accumulate the squared error for this training example into the running network error
    error = nj.add(error, nj.sum(nj.power((nj.subtract(outputLayer, outputs.slice([i, i + 1]))), 2)))
    // Calculate each weight update by taking the derivative of the error function with
    // respect to that weight, and subtract it from the previous weight value. The "delta"
    // variables are just for the sake of reusability. (Strictly, the derivative of
    // (output - target)^2 is 2 * (output - target); the constant factor is absorbed
    // into the learning rate.)
    let outputLayerDelta = nj.subtract(outputLayer, outputs.slice([i, i + 1]))
    let hiddenLayerDelta = nj.multiply(outputLayerDelta.dot(hiddenOutputWeights.T), reluDeriv(hiddenLayer))
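    // Note the ordering: hiddenLayerDelta above is computed from hiddenOutputWeights
    // before they are overwritten below, so backpropagation uses the pre-update weights.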
    hiddenOutputWeights = nj.subtract(hiddenOutputWeights, hiddenLayer.T.dot(outputLayerDelta).multiply(alpha))
    inputHiddenWeights = nj.subtract(inputHiddenWeights, inputLayer.T.dot(hiddenLayerDelta).multiply(alpha))
  }
  // This is just for bookkeeping
  if (iteration % 10 === 9) {
    console.log(`Error: ${error.tolist()}`)
  }
}
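To try this out locally (assuming Node.js is available and the snippet is saved as, say, xor.js):

npm install numjs
node xor.js

The logged error should generally fall toward zero over the 60 iterations; since the weights are randomly initialised and the network has no bias terms, the occasional unlucky run may converge more slowly.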