@kkoch986
Created December 2, 2014 14:09
playing with neural nets
// a Unit holds a value and the gradient of the circuit output with respect to it
var Unit = function(value, grad) {
  this.value = value;
  this.grad = grad;
};
var multiplyGate = function(){ };
multiplyGate.prototype = {
  forward: function(u0, u1) {
    // store pointers to input Units u0 and u1 and output unit utop
    this.u0 = u0;
    this.u1 = u1;
    this.utop = new Unit(u0.value * u1.value, 0.0);
    return this.utop;
  },
  backward: function() {
    // take the gradient in the output unit and chain it with the
    // local gradients, which we derived for the multiply gate before,
    // then write those gradients into the input Units.
    this.u0.grad += this.u1.value * this.utop.grad;
    this.u1.grad += this.u0.value * this.utop.grad;
  }
};
var addGate = function(){ };
addGate.prototype = {
  forward: function(u0, u1) {
    this.u0 = u0;
    this.u1 = u1; // store pointers to input units
    this.utop = new Unit(u0.value + u1.value, 0.0);
    return this.utop;
  },
  backward: function() {
    // add gate: derivative wrt both inputs is 1
    this.u0.grad += 1 * this.utop.grad;
    this.u1.grad += 1 * this.utop.grad;
  }
};
var sigmoidGate = function() {
  // helper function
  this.sig = function(x) { return 1 / (1 + Math.exp(-x)); };
};
sigmoidGate.prototype = {
  forward: function(u0) {
    this.u0 = u0;
    this.utop = new Unit(this.sig(this.u0.value), 0.0);
    return this.utop;
  },
  backward: function() {
    var s = this.sig(this.u0.value);
    this.u0.grad += (s * (1 - s)) * this.utop.grad;
  }
};
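// --------------------------------------------------------------------------------
// Quick sanity check of the gates (an illustrative sketch, separate from the
// training run below): push a value through a sigmoid gate, backpropagate a
// gradient of 1.0, and compare against the analytic derivative s * (1 - s).
// The demo_* names exist only for this example.
var demo_gate = new sigmoidGate();
var demo_in = new Unit(0.5, 0.0);
var demo_out = demo_gate.forward(demo_in);
demo_out.grad = 1.0;
demo_gate.backward();
var demo_s = demo_gate.sig(0.5);
console.log("sigmoid(0.5) =", demo_out.value.toFixed(4),
            "| backprop grad =", demo_in.grad.toFixed(4),
            "| analytic grad =", (demo_s * (1 - demo_s)).toFixed(4));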
// --------------------------------------------------------------------------------
// --------------------------------------------------------------------------------
// --------------------------------------------------------------------------------
// Tweakables
var step_size = 0.000001;
var max_iterations = 10000;
var training_data_size = 100000;
var training_data_input_range = [-10000, 10000];
// Generate training data: random points labeled by a hidden linear model
function model(x, y) {
  return ((2 * x) + (-6 * y) + 40) > 0 ? 1 : -1;
}
var training_data = [];
for(var i = 0 ; i < training_data_size ; i++) {
  // Math.random() takes no arguments, so scale its [0, 1) output into the input range
  var x = training_data_input_range[0] + Math.random() * (training_data_input_range[1] - training_data_input_range[0]);
  var y = training_data_input_range[0] + Math.random() * (training_data_input_range[1] - training_data_input_range[0]);
  training_data.push([x, y, model(x, y)]);
}
// create units for the weights (a, b, c) and the inputs (x, y)
var a = new Unit(1.0, 0.0);
var b = new Unit(2.0, 0.0);
var c = new Unit(-3.0, 0.0);
var x = new Unit(-1.0, 0.0);
var y = new Unit(3.0, 0.0);
var target = -10.0; // not used below
// create the gates
var mulg0 = new multiplyGate();
var mulg1 = new multiplyGate();
var addg0 = new addGate();
var addg1 = new addGate();
// intermediate units, declared out here so the training loop can reach axpbypc
var ax, by, axpby, axpbypc;
// do the forward pass
var forwardNeuron = function() {
  ax = mulg0.forward(a, x); // a*x = -1 with the initial values
  by = mulg1.forward(b, y); // b*y = 6
  axpby = addg0.forward(ax, by); // a*x + b*y = 5
  axpbypc = addg1.forward(axpby, c); // a*x + b*y + c = 2
  return axpbypc.value;
};
for(var j = 0 ; j < max_iterations ; j++) {
  var total = 0;
  var correct = 0;
  for (var i = training_data.length - 1; i >= 0; i--) {
    var test = training_data[i];
    total++;
    // Run the current inputs through the circuit
    x.value = test[0];
    y.value = test[1];
    var current_value = forwardNeuron();
    // Backpropagate to find the gradient
    var classification = ( current_value / Math.abs(current_value) );
    var expected = test[2];
    // reset the accumulated gradients before backpropagating this example
    a.grad = 0; b.grad = 0; c.grad = 0; x.grad = 0; y.grad = 0;
    // pull the output up or down only if this example is misclassified
    axpbypc.grad = 0;
    if(expected < 0 && classification > 0) {
      axpbypc.grad = -1;
    } else if(expected > 0 && classification < 0) {
      axpbypc.grad = 1;
    }
    addg1.backward(); // writes gradients into axpby and c
    addg0.backward(); // writes gradients into ax and by
    mulg1.backward(); // writes gradients into b and y
    mulg0.backward(); // writes gradients into a and x
    // apply the adjustment: nudge the weights along their gradients
    a.value += step_size * a.grad;
    b.value += step_size * b.grad;
    c.value += step_size * c.grad;
    // re-run the circuit to see if this example is now classified correctly
    var newValue = forwardNeuron();
    var classification = newValue / Math.abs(newValue);
    // console.log(test, current_value, axpbypc.grad, classification, classification === expected);
    if(classification === expected) {
      correct++;
    }
  }
  console.log("[" + j + "]: " + total + " total, " + correct + " correct ( " + ( (correct / total) * 100.0 ).toFixed(2) + "% )");
  if(correct === total) {
    break;
  }
}
console.log("A: ", a.value > 1 ? "Increase" : "Decrease", " 1 -> ", a.value.toFixed(2));
console.log("B: ", b.value > 2 ? "Increase" : "Decrease", " 2 -> ", b.value.toFixed(2));
console.log("C: ", c.value > -3 ? "Increase" : "Decrease", "-3 -> ", c.value.toFixed(2));