Skip to content

Instantly share code, notes, and snippets.

Created May 2, 2017 14:06
Show Gist options
  • Save AnneMGal/01c46a6ff8a067bb9ea2bd41b8d3c5f2 to your computer and use it in GitHub Desktop.
Save AnneMGal/01c46a6ff8a067bb9ea2bd41b8d3c5f2 to your computer and use it in GitHub Desktop.
// Daniel Shiffman
// Nature of Code: Intelligence and Learning
// Based on "Make Your Own Neural Network" by Tariq Rashid
// This version classifies characters: class labels are character codes
// (train()/test() use charCodeAt, draw() uses String.fromCharCode).
// Neural Network
var nn;
// Train and Testing Data
// (arrays of CSV lines, filled in by preload())
var training;
var testing;
// Where are we in the training and testing data
// (for animation)
var trainingIndex = 0;
var testingIndex = 0;
// How many times through all the training data
var epochs = 0;
// Network configuration
// One input per pixel of a 28x28 image (28 * 28 = 784)
var input_nodes = 784;
var hidden_nodes = 256;
// for numbers
//var output_nodes = 10;
// for ascii: one output per character code 0-126
var output_nodes = 127;
// Learning rate
var learning_rate = 0.1;
// How is the network doing
// (running counts used for the performance ratio shown in draw())
var totalCorrect = 0;
var totalGuesses = 0;
// Reporting status to a paragraph
var statusP;
// This is for a user drawn image
// userPixels: the drawing surface; smaller: its 28x28 down-sampled copy,
// whose pixels are fed to the network
var userPixels;
var smaller;
// Top-left corner (ux, uy) and side length uw of the square drawing area
var ux = 16;
var uy = 100;
var uw = 140;
// Load training and testing data
// Note this is not the full dataset (per the file names, 10,000 training
// rows and 1,000 testing rows).
// NOTE(review): the source link that followed "From:" is missing from this
// copy of the sketch — restore it.

/**
 * p5.js preload hook: loads the training and testing CSV files as arrays of
 * text lines (one row per line) before setup() runs.
 */
function preload() {
  training = loadStrings('data/mnist_train_10000.csv');
  testing = loadStrings('data/mnist_test_1000.csv');
}
/**
 * p5.js setup hook: creates the canvas, the neural network, the status
 * paragraph, the pause / clear / save buttons, and the user drawing surfaces.
 *
 * Each button's handler is registered with mousePressed(); without that
 * registration the buttons do nothing.
 */
function setup() {
  // Canvas
  createCanvas(320, 280);
  // pixelDensity(1);

  // Create the neural network
  nn = new NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate);

  // Status paragraph
  statusP = createP('');

  // Pause / resume button
  var pauseButton = createButton('pause');
  pauseButton.mousePressed(toggle);

  // Toggle the state to start and stop. The button's label doubles as the
  // current state: 'pause' means the draw loop is running.
  function toggle() {
    if (pauseButton.html() === 'pause') {
      noLoop();
      pauseButton.html('continue');
    } else {
      loop();
      pauseButton.html('pause');
    }
  }

  // This button clears the user pixels
  var clearButton = createButton('clear');
  clearButton.mousePressed(clearUserPixels);
  // Just draw a black background
  function clearUserPixels() {
    userPixels.background(0);
  }

  // Save the model
  var saveButton = createButton('save model');
  saveButton.mousePressed(saveModelJSON);
  // Save the whole model as a JSON file
  // TODO: add reloading functionality!
  function saveModelJSON() {
    // Take the neural network object and download
    saveJSON(nn, 'model.json');
  }

  // Create a blank (black) user canvas, matching what clearUserPixels draws
  userPixels = createGraphics(uw, uw);
  userPixels.background(0);

  // Create a smaller 28x28 image
  smaller = createImage(28, 28, RGB);
  // This is sort of silly, but copying the (blank) user pixels down means
  // the preview starts out blank too
  var img = userPixels.get();
  smaller.copy(img, 0, 0, uw, uw, 0, 0, smaller.width, smaller.height);
}
// When the mouse is dragged, draw onto the user pixels

/**
 * p5.js mouse-drag hook: paints a white dab onto the user drawing surface
 * and re-samples that surface into the 28x28 `smaller` image.
 * Drags outside the drawing area (bounds are exclusive) are ignored.
 */
function mouseDragged() {
  // Only if the user drags within the user pixels area
  var inside = mouseX > ux && mouseY > uy && mouseX < ux + uw && mouseY < uy + uw;
  if (!inside) {
    return;
  }
  // Draw a white circle. Set both fill and stroke explicitly: p5's default
  // stroke is black and would ring every dab with a dark outline.
  userPixels.fill(255);
  userPixels.stroke(255);
  userPixels.ellipse(mouseX - ux, mouseY - uy, 16, 16);
  // Sample down into the smaller p5.Image object
  var img = userPixels.get();
  smaller.copy(img, 0, 0, uw, uw, 0, 0, smaller.width, smaller.height);
}
/**
 * p5.js per-frame callback: trains on one row, tests on one row, draws both
 * images plus the network's test guess, builds a performance/epoch status
 * string, and runs the network on the user's down-sampled drawing.
 *
 * NOTE(review): this copy of draw() is incomplete. Its closing braces are
 * missing, and several statements that the surrounding comments imply are
 * absent (see the NOTE(review) markers below). Restore it from the upstream
 * sketch before relying on it.
 */
function draw() {
// Train (this does just one image per cycle through draw)
var traindata = train();
// Test
var result = test();
// The results come back as an array of 3 things
// Input data
var testdata = result[0];
// What was the guess?
var guess = result[1];
// Was it correct?
var correct = result[2];
// Draw the training and testing image
drawImage(traindata, ux, 16, 2, 'training');
drawImage(testdata, 180, 16, 2, 'test');
// Draw the resulting guess
// NOTE(review): no fill is set before this rect in the visible code, so it
// is drawn with whatever fill happens to be current.
rect(246, 16, 2 * 28, 2 * 28);
// Was it right or wrong? Sets the fill to green if correct, red otherwise.
if (correct) {
fill(0, 255, 0);
} else {
fill(255, 0, 0);
// show the raw ascii int *
text(guess, 230, 264);
// convert to the string *
text(String.fromCharCode(guess), 257, 64);
// String.fromCharCode(48) == '0'
// Tally total correct
// NOTE(review): the body of this `if` is missing, and nothing visible
// increments totalCorrect or totalGuesses, so the performance ratio below
// would be 0 / 0 = NaN.
if (correct) {
// Show performance and # of epochs
var status = 'performance: ' + nf(totalCorrect / totalGuesses, 0, 2);
status += '<br>';
// Progress (in percent) through the current pass over the training data
var percent = 100 * trainingIndex / training.length;
status += 'epochs: ' + epochs + ' (' + nf(percent, 1, 2) + '%)';
// NOTE(review): `status` is never written to statusP in the visible code.
// Draw the user pixels
image(userPixels, ux, uy);
text('draw here', ux, uy + uw + 16);
// Draw the sampled down image
image(smaller, 180, uy, 28 * 2, 28 * 2);
// Change the pixels from the user into network inputs
// NOTE(review): p5 requires smaller.loadPixels() before smaller.pixels is
// read; no such call is visible here.
var inputs = [];
for (var i = 0; i < smaller.pixels.length; i += 4) {
// Just using the red channel since it's a greyscale image
// Not so great to use inputs of 0 so smallest value is 0.01
inputs[i / 4] = map(smaller.pixels[i], 0, 255, 0, 0.99) + 0.01;
// Get the outputs
var outputs = nn.query(inputs);
// What is the best guess?
// (Re-declares `guess`; the test-set guess above is not used after this.)
var guess = findMax(outputs);
// Draw the resulting guess, converting the char code back to a character
rect(246, uy, 2 * 28, 2 * 28);
text(String.fromCharCode(guess), 258, uy + 48);
// Function to train the network

/**
 * Trains the network on the current row of the training CSV, then advances
 * to the next row. An epoch is counted each time the data set wraps around.
 *
 * Row format: the first field is a single character whose char code is the
 * class label; the remaining fields are pixel values 0-255. All whitespace is
 * stripped before splitting.
 *
 * Rows whose label does not map onto an output node are not trained on, but
 * still advance the index. This covers a blank line (NaN label), a
 * whitespace label (stripped away), and char codes >= output_nodes.
 *
 * @returns {number[]} the normalized inputs of the row, for drawing.
 */
function train() {
  // Grab a row from the CSV
  var values = training[trainingIndex].replace(/\s+/g, '').split(',');
  // The class is the char code of the first field
  var label = values[0].charCodeAt(0);
  // print ascii code that is actually used for training
  text(label, 75, 66);

  // Make an input array from fields 1 onward, normalized to 0.01-1.0
  // (not so great to use inputs of 0, so add 0.01)
  var inputs = [];
  for (var i = 1; i < values.length; i++) {
    inputs[i - 1] = map(Number(values[i]), 0, 255, 0, 0.99) + 0.01;
  }

  if (label >= 0 && label < output_nodes) {
    // Everything by default is wrong ...
    var targets = [];
    for (var k = 0; k < output_nodes; k++) {
      targets[k] = 0.01;
    }
    // ... except the labelled class, which should get a 0.99 output
    targets[label] = 0.99;
    // Train with these inputs and targets
    nn.train(inputs, targets);
  }

  // Go to the next training data point; one cycle through all the training
  // data is one epoch
  trainingIndex++;
  if (trainingIndex >= training.length) {
    trainingIndex = 0;
    epochs++;
  }
  // Return the inputs to draw them
  return inputs;
}
// Function to test the network

/**
 * Runs the network on the current row of the testing CSV (same row format as
 * the training data). The testing row only advances every 30 frames, so each
 * test image stays on screen long enough to be seen.
 *
 * @returns {Array} [inputs, guess, correct]: the normalized inputs, the index
 *   of the strongest output (a char code), and whether it matches the label.
 */
function test() {
  // Grab a row from the CSV
  var values = testing[testingIndex].replace(/\s+/g, '').split(',');
  // The first spot is the class, as the char code of its character
  var label = values[0].charCodeAt(0);

  // Normalize the inputs 0-1; not so great to use inputs of 0, so add 0.01
  var inputs = [];
  for (var i = 1; i < values.length; i++) {
    inputs[i - 1] = map(Number(values[i]), 0, 255, 0, 0.99) + 0.01;
  }

  // Run the data through the network and take the most confident output
  var outputs = nn.query(inputs);
  var guess = findMax(outputs);
  // Was the network right or wrong?
  var correct = guess === label;

  // Switch to a new testing data point every so often, wrapping at the end
  if (frameCount % 30 === 0) {
    testingIndex++;
    if (testingIndex >= testing.length) {
      testingIndex = 0;
    }
  }
  // For reporting in draw return the results
  return [inputs, guess, correct];
}
// A function to find the maximum value in an array

/**
 * Returns the index of the largest value in `list` (the first one on ties).
 * An empty list yields 0; NaN entries never win.
 *
 * @param {number[]} list - values to scan
 * @returns {number} index of the highest value
 */
function findMax(list) {
  // Start below every real number so lists whose values are all <= 0 are
  // handled correctly (starting at 0 would always report index 0 for them).
  var record = -Infinity;
  var index = 0;
  for (var i = 0; i < list.length; i++) {
    if (list[i] > record) {
      record = list[i];
      index = i;
    }
  }
  return index;
}
// Draw the array of floats as an image

/**
 * Draws an array of values in the 0-1 range as a 28-cell-wide greyscale
 * image with its top-left corner at (xoff, yoff), each value as a w x w
 * square, then writes `txt` in black underneath.
 *
 * All style changes are wrapped in push()/pop(), so the caller's fill and
 * stroke settings are left as they were.
 *
 * @param {number[]} values - intensities, row-major, 28 per row
 * @param {number} xoff - left edge on the canvas
 * @param {number} yoff - top edge on the canvas
 * @param {number} w - side length of each drawn cell in canvas pixels
 * @param {string} txt - label drawn below the image
 */
function drawImage(values, xoff, yoff, w, txt) {
  // it's a 28 x 28 image
  var dim = 28;
  push();
  // Without noStroke() p5's default outline would be drawn around every
  // cell, hiding much of each small (w = 2 here) square's fill.
  noStroke();
  for (var k = 0; k < values.length; k++) {
    // Scale 0-1 up to the 0-255 colour range and apply it to this cell
    fill(values[k] * 255);
    // Find x and y
    var x = k % dim;
    var y = Math.floor(k / dim);
    rect(xoff + x * w, yoff + y * w, w, w);
  }
  // Draw a label below (35 cells down leaves a gap under the 28-cell image)
  fill(0);
  text(txt, xoff, yoff + w * 35);
  pop();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment