Skip to content

Instantly share code, notes, and snippets.

@antimatter15
Created August 26, 2014 07:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antimatter15/3ad2c621bc2e342c6805 to your computer and use it in GitHub Desktop.
Save antimatter15/3ad2c621bc2e342c6805 to your computer and use it in GitHub Desktop.
convnet.js training ocr
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>ConvNetJS OCR demo</title>
<meta name="description" content="">
<meta name="author" content="">
<canvas id="blah" width="28" height="28"></canvas>
<script src="http://ajax.googleapis.com/ajax/libs/webfont/1/webfont.js"></script>
<style>
.layer {
border: 1px solid #999;
margin-bottom: 5px;
text-align: left;
padding: 10px;
}
.layer_act {
width: 500px;
float: right;
}
.ltconv {
background-color: #FDD;
}
.ltrelu {
background-color: #FDF;
}
.ltpool {
background-color: #DDF;
}
.ltsoftmax {
background-color: #FFD;
}
.ltfc {
background-color: #DFF;
}
.ltlrn {
background-color: #DFD;
}
.ltdropout {
background-color: #AAA;
}
.ltitle {
color: #333;
font-size: 18px;
}
.actmap {
margin: 1px;
}
#trainstats {
text-align: left;
}
.clear {
clear: both;
}
#wrap {
width: 1000px;
margin-left: auto;
margin-right: auto;
}
h1 {
font-size: 16px;
color: #333;
background-color: #DDD;
border-bottom: 1px #999 solid;
text-align: center;
}
.secpart {
width: 400px;
float: left;
}
#lossgraph {
/*border: 1px solid #F0F;*/
width: 100%;
}
.probsdiv canvas {
float: left;
}
.probsdiv {
height: 60px;
width: 180px;
display: inline-block;
font-size: 12px;
box-shadow: 0px 0px 2px 2px #EEE;
margin: 5px;
padding: 5px;
color: black;
}
.pp {
margin: 1px;
padding: 1px;
}
#testset_acc {
margin-bottom: 200px;
}
body {
font-family: Arial, "Helvetica Neue", Helvetica, sans-serif;
}
</style>
<script src="jquery-1.8.3.min.js"></script>
<script src="../build/vis.js"></script>
<script src="../build/util.js"></script>
<script src="../build/convnet.js"></script>
<script>
// Font families requested from Google through the WebFont loader script that
// is included in <head>.
var google_fonts = [
  'Droid Sans', 'Philosopher', 'Alegreya Sans', 'Chango', 'Coming Soon', 'Allan',
  'Cardo', 'Bubbler One', 'Bowlby One SC', 'Prosto One', 'Rufina', 'Cantora One',
  'Denk One', 'Play', 'Architects Daughter', 'Nova Square', 'Inder',
  'Gloria Hallelujah', 'Telex', 'Comfortaa', 'Merienda', 'Boogaloo', 'Krona One',
  'Orienta', 'Sofadi One', 'Source Sans Pro', 'Revalia', 'Overlock', 'Kelly Slab',
  'Rye', 'Lato', 'Milonga', 'Aladin', 'Audiowide', 'Italiana', 'Michroma',
  'Cabin Condensed', 'Jura', 'Marko One', 'PT Mono', 'Bubblegum Sans', 'Amaranth'
];

// Start the download. When the loader invokes the `active` callback, every
// requested family is appended to the global `fonts` pool that the training
// sample generator draws from.
WebFont.load({
  google: { families: google_fonts },
  active: function() {
    fonts = fonts.concat(google_fonts);
  }
});
// Network definition, network instance and trainer. They are assigned when the
// text of the #newnet textarea is eval'd.
var layer_defs, net, trainer;

// Output classes of the classifier: one array entry per class. Every character
// inside an entry is trained with the same label (sample_training_instance
// uses the index of the entry that contains the drawn character), so entries
// such as 'Oo0' merge several glyphs into one class -- presumably because they
// look alike in many fonts.
// Punctuation classes (! @ # $ % ^ & * - + = brackets, slashes, quotes, ...)
// were tried in earlier versions and are currently left out.
var symbols = [
  'a', 'A', 'b', 'B', 'Cc', 'd', 'D', 'e', 'E', 'f', 'F',
  'g9', 'G6', 'h', 'H', 'iI1l', 'Jj', 'k', 'K', 'L',
  'm', 'M', 'n', 'N', 'Oo0', 'Pp', 'q', 'Q', 'r', 'R',
  'Ss5', 't', 'T', 'Uu', 'vV', 'Ww', 'Xx', 'Yy', 'Zz2',
  '3', '4', '7', '8'
];
// Moment this script started running; the snapshot below reports the elapsed
// time relative to it.
var startTime = Date.now();

// Once a minute, serialize the training progress, the class list and the full
// network into a small JavaScript source text and hand it to fafnir(), which
// POSTs it to a server on 127.0.0.1.
setInterval(function() {
  var name = 'commonequivfreqnoise28x8x18';
  var trained_info = {
    name: name,
    examples_seen: step_num,
    age_minutes: (Date.now() - startTime) / 1000 / 60,
    train_accuracy: trainAccWindow.get_average(),
    validation_accuracy: valAccWindow.get_average()
  };
  var snapshot =
    'trained_info = ' + JSON.stringify(trained_info, null, ' ') +
    ';\n\ntrained_symbols = ' + JSON.stringify(symbols) +
    ';\n\ntrained_network = ' + JSON.stringify(net.toJSON()) +
    ';\n';
  fafnir(name, snapshot);
}, 60 * 1000);
// var symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLNOPQRSTUVWXYZ0123456789'.split('')
// Default text for the #newnet textarea. It is JavaScript source that gets
// eval'd to build the layer definitions, the network and the trainer. The
// softmax size follows the number of entries in `symbols`.
var t = [
  "layer_defs = [];",
  "layer_defs.push({type:'input', out_sx:28, out_sy:28, out_depth:1});",
  "layer_defs.push({type:'conv', sx:5, filters:8, stride:1, pad:2, activation:'relu'});",
  "layer_defs.push({type:'pool', sx:2, stride:2});",
  "layer_defs.push({type:'conv', sx:5, filters:18, stride:1, pad:2, activation:'relu'});",
  "layer_defs.push({type:'pool', sx:5, stride:2});",
  "layer_defs.push({type:'softmax', num_classes:" + symbols.length + "});",
  "",
  "net = new convnetjs.Net();",
  "net.makeLayers(layer_defs);",
  "",
  "trainer = new convnetjs.SGDTrainer(net, {method:'adadelta', batch_size:20, l2_decay:0.001});",
  ""  // keeps the trailing newline of the original text
].join("\n");
// ------------------------
// BEGIN OCR SPECIFIC STUFF
// ------------------------
// Comic
// Impact
// Sans
// Serif
// Random jitter: a uniformly distributed value in [-range, range).
function r(range) {
  var centered = Math.random() - 0.5; // in [-0.5, 0.5)
  return 2 * centered * range;
}
// Widely used fonts. They open the `fonts` list and are appended two more
// times at the end, so each of them occurs three times in the list that the
// sample generator draws from uniformly.
var common_fonts = [
  'Helvetica', 'Helvetica Neue', 'Arial', 'Verdana', 'Arial Black', 'Arial Narrow',
  'Myriad Pro', 'Segoe UI', 'Tahoma', 'Trebuchet MS',
  'Times New Roman', 'Garamond', 'Futura',
  'Courier', 'Courier New', 'Menlo', 'Comic Sans MS'
];

// Font pool used when rendering training glyphs: the common fonts first,
// followed by a wider selection of locally installed families.
var fonts = common_fonts.concat([
  // casual / handwriting-like
  'Comic Neue', 'Chalkboard', 'Comic Neue Angular', 'Tekton Pro',
  // sans-serif
  'Source Sans Pro', 'Avenir', 'Avenir Next', 'Gill Sans', 'Lucida Grande',
  // serif
  'Athelas', 'Baskerville', 'Chaparral Pro', 'CMU Serif', 'Cochin', 'Georgia',
  'Marion', 'Minion Pro', 'Palatino',
  // assorted
  'American Typewriter', 'Andale Mono', 'Arial Rounded',
  'Geneva',
  'Hoefler Text', 'Iowan Old Style',
  'OCR A Std',
  'Hobo Std', 'Didot',
  'Brandon Grotesque',
  'Charter',
  'DIN Alternate', 'DIN Condensed',
  'Kozuka Gothic Pro', 'Kozuka Mincho Pro', 'Letter Gothic Std', 'Marker Felt', 'Monaco',
  'Noteworthy', 'Optima', 'Prestige Elite Std', 'PT Serif', 'Seravek', 'Skia', 'STIXGeneral',
  'Superclarendon', 'Thonburi', 'Adobe Caslon Pro', 'Adobe Arabic'
]);

// Two extra copies of the common fonts raise their share of the samples.
fonts = fonts.concat(common_fonts, common_fonts);
// var famlist = Object.keys(fams);
// upper=[3,8,5,12,11,17,19,13,6,20,22,15,4,7,16,10,24,9,2,1,21,23,14,25,18,26]
// lower=[3,20,12,11,1,15,17,9,6,25,22,10,14,5,4,16,26,8,7,2,13,21,19,23,18,24]
// var etaoin = 'zxqjkgbvpywfmculdrhsnioate'
// var ETAOIN = 'ZXQVKUJGYFOLWHDEPRBNICMAST';
// var etaoin = 'qjzxkvbwygpfmucdlhrsinoate';
// Ranking of the letters used to bias sampling (presumably ordered from least
// to most frequent letter, as the name suggests -- TODO confirm).
var etaoin = 'ZqXjQzVxKkUvJbGwYyFgOpLfWmHuDcEdPlRhBrNsIiCnMoAaStTe';

// All characters of all classes as one string, sorted so that characters found
// later in `etaoin` come first. Characters that are not in `etaoin` (the
// digits; indexOf gives -1) compare lowest and therefore end up at the back.
var joined = (function() {
  var chars = symbols.join('').split('');
  chars.sort(function(first, second) {
    return etaoin.indexOf(second) - etaoin.indexOf(first);
  });
  return chars.join('');
})();
// Render one random training example.
//
// A character is drawn from `joined`, painted white on black into the #blah
// canvas with a random font, weight, size, offset and rotation, read back as a
// 32x32 single-channel Vol with random pixel noise, and finally passed through
// convnetjs.augment(vol, 28).
//
// Returns { x: Vol, label: index into `symbols`, isval: boolean }, where about
// one sample in ten is flagged as validation data.
//
// NOTE(review): the #blah canvas is declared 28x28 in the markup while 32x32
// pixels are filled and read here; canvas pixels outside the element should
// read back as 0 -- confirm that this padding is intended.
function sample_training_instance() {
  // The product of two uniform numbers is skewed towards 0, so characters near
  // the front of `joined` are picked more often.
  var pick = Math.floor(joined.length * Math.random() * Math.random());
  var letter = joined[pick];

  // The label is the index of the first class that contains the character.
  var label;
  for (var s = 0; s < symbols.length && label === undefined; s++) {
    if (symbols[s].indexOf(letter) !== -1) {
      label = s;
    }
  }

  var fonfam = fonts[Math.floor(fonts.length * Math.random())];

  var SIDE = 32;
  var ctx = document.getElementById('blah').getContext('2d');
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, SIDE, SIDE);
  ctx.textBaseline = 'middle';
  ctx.textAlign = 'center';

  // Font weight 100..800 in steps of 100, size 20px +/- 2px.
  var weight = 100 * Math.floor(1 + 8 * Math.random());
  var sizePx = 20 + r(2);
  ctx.font = weight + ' ' + sizePx + 'px "' + fonfam + '"';
  ctx.fillStyle = 'white';

  // Draw the glyph around (15, 15) with up to 1px shift and a slight rotation.
  ctx.save();
  var centerX = 15 + r(1);
  var centerY = 15 + r(1);
  ctx.translate(centerX, centerY);
  ctx.rotate(r(0.1));
  ctx.fillText(letter, 0, 0);
  ctx.restore();

  // Copy the red channel into the Vol, scaled to [0, 1]; roughly every tenth
  // pixel is replaced by uniform noise.
  var rgba = ctx.getImageData(0, 0, SIDE, SIDE).data;
  var vol = new convnetjs.Vol(SIDE, SIDE, 1, 0.0);
  for (var px = 0; px < SIDE * SIDE; px++) {
    vol.w[px] = (Math.random() < 0.1) ? Math.random() : rgba[px * 4] / 255.0;
  }

  return {
    x: convnetjs.augment(vol, 28),
    label: label,
    isval: Math.random() > 0.9
  };
}
// console.log(x)
// symbols = ['0','1','2','3','4','5','6','7','8','9'];
// symbols = famlist;
// symbols = symbols.split('')
// Validation switch kept from the batch-image version of this demo.
// NOTE(review): within this file only commented-out code reads the flag; the
// live sample generator decides on its own which samples are validation data.
// (A no-op `symbols = symbols;` self-assignment that used to precede this
// declaration has been dropped.)
var use_validation_data = true;
// var sample_training_instance = function() {
// // find an unloaded batch
// var bi = Math.floor(Math.random()*loaded_train_batches.length);
// var b = loaded_train_batches[bi];
// var k = Math.floor(Math.random()*3000); // sample within the batch
// var n = b*3000+k;
// // load more batches over time
// if(step_num%5000===0 && step_num>0) {
// for(var i=0;i<num_batches;i++) {
// if(!loaded[i]) {
// // load it
// load_data_batch(i);
// break; // okay for now
// }
// }
// }
// // fetch the appropriate row of the training image and reshape into a Vol
// var p = img_data[b].data;
// var x = new convnetjs.Vol(28,28,1,0.0);
// var W = 28*28;
// for(var i=0;i<W;i++) {
// var ix = ((W * k) + i) * 4;
// x.w[i] = p[ix]/255.0;
// }
// x = convnetjs.augment(x, 24);
// var isval = use_validation_data && n%10===0 ? true : false;
// return {x:x, label:labels[n], isval:isval};
// }
// sample a random testing instance
// var sample_test_instance = function() {
// var b = 20;
// var k = Math.floor(Math.random()*3000);
// var n = b*3000+k;
// var p = img_data[b].data;
// var x = new convnetjs.Vol(28,28,1,0.0);
// var W = 28*28;
// for(var i=0;i<W;i++) {
// var ix = ((W * k) + i) * 4;
// x.w[i] = p[ix]/255.0;
// }
// var xs = [];
// for(var i=0;i<4;i++) {
// xs.push(convnetjs.augment(x, 24));
// }
// // return multiple augmentations, and we will average the network over them
// // to increase performance
// return {x:xs, label:labels[n]};
// }
// Test instances come from the same generator as training instances; there is
// no separate held-out data set.
var sample_test_instance = function() {
  var instance = sample_training_instance();
  return instance;
};
// Bookkeeping for the batch-image loader, which is commented out in this
// version of the demo. Only `loaded` is still touched (reset on page load).
var num_batches = 21; // 20 training batches, 1 test
var data_img_elts = new Array(num_batches),
    img_data = new Array(num_batches),
    loaded = new Array(num_batches);
var loaded_train_batches = [];
// Page entry point (the demo's "main"): runs when the window load event fires.
$(window).load(function() {
// Put the default network definition text into the #newnet textarea ...
$("#newnet").val(t);
// ... and run that text, which assigns the globals layer_defs, net and trainer.
// NOTE(review): eval executes whatever the textarea holds; at this point it is
// the definition string written on the line above.
eval($("#newnet").val());
// Copy the trainer's hyper-parameters into the input fields.
update_net_param_display();
// Mark every batch slot as not loaded (the loader calls below are disabled).
for(var k=0;k<loaded.length;k++) { loaded[k] = false; }
// load_data_batch(0); // async load train set batch 0 (6 total train batches)
// load_data_batch(20); // async load test set (batch 6)
// Start the training loop.
start_fun();
});
// Start training. Samples are generated on the fly, so there is nothing to
// wait for (the old version polled here until the image batches were loaded).
var start_fun = function() {
  console.log('starting!');
  // Zero-delay interval: run one load_and_step call as often as the browser's
  // timer allows.
  setInterval(load_and_step, 0);
};
// var load_data_batch = function(batch_num) {
// // Load the dataset with JS in background
// data_img_elts[batch_num] = new Image();
// var data_img_elt = data_img_elts[batch_num];
// data_img_elt.onload = function() {
// var data_canvas = document.createElement('canvas');
// data_canvas.width = data_img_elt.width;
// data_canvas.height = data_img_elt.height;
// var data_ctx = data_canvas.getContext("2d");
// data_ctx.drawImage(data_img_elt, 0, 0); // copy it over... bit wasteful :(
// img_data[batch_num] = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height);
// loaded[batch_num] = true;
// if(batch_num < 20) { loaded_train_batches.push(batch_num); }
// console.log('finished loading data batch ' + batch_num);
// };
// data_img_elt.src = "mnist/mnist_batch_" + batch_num + ".png";
// }
// ------------------------
// END OCR SPECIFIC STUFF
// ------------------------
// Short aliases for two helpers of the cnnutil library (loaded from util.js).
// The result of maxmin(w) is read below through its minv, maxv and dv fields.
var maxmin = cnnutil.maxmin;
// f2t turns a number into the text shown in the statistics; the exact format
// is defined in cnnutil and not visible in this file.
var f2t = cnnutil.f2t;
// Draw every depth slice of a Vol as a grayscale canvas and append the
// canvases to a DOM element.
//   elt   - element that receives one <canvas class="actmap"> per depth slice
//   A     - the Vol to draw
//   scale - edge length in pixels of the square drawn per Vol cell (default 2)
//   grads - when given and truthy, gradients (A.dw) are drawn instead of the
//           values (A.w)
// Intensities are normalized with the minimum and range that maxmin() reports
// for the whole Vol, so all slices share one gray scale.
var draw_activations = function(elt, A, scale, grads) {
  var zoom = scale || 2;
  var useGrads = (typeof grads !== 'undefined') ? grads : false;
  var range = maxmin(useGrads ? A.dw : A.w);
  var widthPx = A.sx * zoom;
  var heightPx = A.sy * zoom;

  for (var d = 0; d < A.depth; d++) {
    var canvas = document.createElement('canvas');
    canvas.className = 'actmap';
    canvas.width = widthPx;
    canvas.height = heightPx;
    var ctx = canvas.getContext('2d');
    var image = ctx.createImageData(widthPx, heightPx);

    for (var x = 0; x < A.sx; x++) {
      for (var y = 0; y < A.sy; y++) {
        var raw = useGrads ? A.get_grad(x, y, d) : A.get(x, y, d);
        var shade = Math.floor((raw - range.minv) / range.dv * 255);

        // Paint the zoom x zoom block of pixels that represents this cell.
        for (var dx = 0; dx < zoom; dx++) {
          for (var dy = 0; dy < zoom; dy++) {
            var base = ((widthPx * (y * zoom + dy)) + (dx + x * zoom)) * 4;
            image.data[base] = shade;     // red
            image.data[base + 1] = shade; // green
            image.data[base + 2] = shade; // blue
            image.data[base + 3] = 255;   // alpha: opaque
          }
        }
      }
    }

    ctx.putImageData(image, 0, 0);
    elt.appendChild(canvas);
  }
};
// Rebuild the network visualization inside `elt`: one box per layer showing
// its activations (plus weights and gradients for conv layers with filters
// wider than 3), its output size, activation range and parameter count.
var visualize_activations = function(net, elt) {
  // Small DOM helpers used only in this function.
  function addText(parent, text) {
    parent.appendChild(document.createTextNode(text));
  }
  function addBreak(parent) {
    parent.appendChild(document.createElement('br'));
  }
  function addLine(parent, text) {
    addText(parent, text);
    addBreak(parent);
  }

  elt.innerHTML = "";

  var layerCount = net.layers.length;
  for (var li = 0; li < layerCount; li++) {
    var layer = net.layers[li];
    var kind = layer.layer_type;
    var isConv = kind === 'conv';
    var f;

    var layerBox = document.createElement('div');

    // Right-hand column: activation maps, larger for the 1x1 output layers.
    var actBox = document.createElement('div');
    addLine(actBox, 'Activations:');
    actBox.className = 'layer_act';
    var actScale = (kind === 'softmax' || kind === 'fc') ? 10 : 2;
    draw_activations(actBox, layer.out_act, actScale);

    // Conv layers also show their filters when they are of reasonable size.
    if (isConv) {
      var filterBox = document.createElement('div');
      if (layer.filters[0].sx > 3) {
        addLine(filterBox, 'Weights:');
        for (f = 0; f < layer.filters.length; f++) {
          draw_activations(filterBox, layer.filters[f], 2);
        }
        addBreak(filterBox);
        addLine(filterBox, 'Gradients:');
        for (f = 0; f < layer.filters.length; f++) {
          draw_activations(filterBox, layer.filters[f], 2, true);
        }
      } else {
        addText(filterBox, 'Weights hidden, too small');
      }
      actBox.appendChild(filterBox);
    }
    layerBox.appendChild(actBox);

    // Left-hand column: textual statistics. The class selects the layer color.
    layerBox.className = 'layer ' + 'lt' + kind;
    var titleBox = document.createElement('div');
    titleBox.className = 'ltitle';
    addText(titleBox, kind + ' (' + layer.out_sx + 'x' + layer.out_sy + 'x' + layer.out_depth + ')');
    layerBox.appendChild(titleBox);

    if (isConv) {
      addLine(layerBox, 'filter size ' + layer.filters[0].sx + 'x' + layer.filters[0].sy + 'x' + layer.filters[0].depth + ', stride ' + layer.stride);
    }
    if (kind === 'pool') {
      addLine(layerBox, 'pooling size ' + layer.sx + 'x' + layer.sy + ', stride ' + layer.stride);
    }

    var extremes = maxmin(layer.out_act.w);
    addLine(layerBox, 'max activation: ' + f2t(extremes.maxv) + ', min: ' + f2t(extremes.minv));

    // Parameter count: weights plus one bias per filter.
    var total;
    if (isConv) {
      total = layer.sx * layer.sy * layer.in_depth * layer.filters.length + layer.filters.length;
      addLine(layerBox, 'parameters: ' + layer.filters.length + 'x' + layer.sx + 'x' + layer.sy + 'x' + layer.in_depth + '+' + layer.filters.length + ' = ' + total);
    }
    if (kind === 'fc') {
      total = layer.num_inputs * layer.filters.length + layer.filters.length;
      addLine(layerBox, 'parameters: ' + layer.filters.length + 'x' + layer.num_inputs + '+' + layer.filters.length + ' = ' + total);
    }

    // Empty div that clears the float of the activation column.
    var clearBox = document.createElement('div');
    clearBox.className = 'clear';
    layerBox.appendChild(clearBox);

    elt.appendChild(layerBox);
  }
};
// True while the user has paused training with the pause/resume button.
var paused = false;

// One tick of the training loop: unless paused, generate a sample and hand it
// to step().
var load_and_step = function() {
  if (!paused) {
    step(sample_training_instance());
  }
};
// Evaluate the current network on 50 freshly sampled test instances and show
// one card per instance in #testset_acc: the input image next to the three
// most probable classes, colored green where the class matches the label and
// red otherwise.
var test_predict = function() {
  var num_classes = net.layers[net.layers.length - 1].out_depth;
  document.getElementById('testset_acc').innerHTML = '';

  // `num` is declared here; it previously leaked as an implicit global.
  for (var num = 0; num < 50; num++) {
    var sample = sample_test_instance();
    var y = sample.label; // ground truth label

    // Sum the network output over all input Vols of the sample. [].concat
    // turns a single Vol into a one-element list and leaves a list unchanged.
    var aavg = new convnetjs.Vol(1, 1, num_classes, 0.0);
    var xs = [].concat(sample.x);
    var n = xs.length;
    for (var i = 0; i < n; i++) {
      aavg.addFrom(net.forward(xs[i]));
    }

    // Rank the classes by summed probability, highest first. The comparator
    // returns 0 for ties so that it is consistent in both argument orders.
    var preds = [];
    for (var k = 0; k < aavg.w.length; k++) {
      preds.push({ k: k, p: aavg.w[k] });
    }
    preds.sort(function(a, b) { return b.p - a.p; });

    // The card itself carries the 'probsdiv' class: the stylesheet's
    // `.probsdiv canvas` rule targets the image canvas drawn into it.
    var div = document.createElement('div');
    div.className = 'probsdiv';
    draw_activations(div, xs[0], 2); // draw Vol into canv

    // Bars for the top predictions; the bar width is the averaged probability
    // in percent. Never read past the end of preds.
    var probsdiv = document.createElement('div');
    div.style.backgroundColor = (preds[0].k === y) ? 'rgba(85,187,85,0.5)' : 'rgba(187,85,85,0.5)';
    var rows = Math.min(3, symbols.length, preds.length);
    var t = '';
    for (var j = 0; j < rows; j++) {
      var col = preds[j].k === y ? 'rgb(85,187,85)' : 'rgb(187,85,85)';
      t += '<div class="pp" style="width:' + Math.floor(preds[j].p / n * 100) + 'px; margin-left: 60px; background-color:' + col + ';">' + symbols[preds[j].k] + '</div>';
    }
    probsdiv.innerHTML = t;
    div.appendChild(probsdiv);

    // add it into DOM
    $("#testset_acc").append(div).fadeIn(1000);
  }
};
// Graph of the loss over time; step() adds points to it and draws it into the
// #lossgraph canvas.
var lossGraph = new cnnvis.Graph();
// Running statistics, each a cnnutil.Window constructed with 100 (presumably
// the number of most recent values that get_average() covers -- confirm in
// cnnutil). step() feeds them and reads their averages.
var xLossWindow = new cnnutil.Window(100);
var wLossWindow = new cnnutil.Window(100);
var trainAccWindow = new cnnutil.Window(100);
var valAccWindow = new cnnutil.Window(100);
// Counter advanced by step() after each training example (validation examples
// return before the increment).
var step_num = 0;
// Process one sample.
// Validation samples only update the validation accuracy window. All other
// samples are trained on, after which the statistics text is rewritten and, at
// fixed intervals of step_num, the network visualization (every 100), the loss
// graph (every 200) and the test predictions (every 1000) are refreshed.
var step = function(sample) {
  var x = sample.x;
  var y = sample.label;

  if (sample.isval) {
    // use x to build our estimate of validation error, without training on it
    net.forward(x);
    valAccWindow.add(net.getPrediction() === y ? 1.0 : 0.0);
    return;
  }

  // train on it with network
  var stats = trainer.train(x, y);

  // keep track of stats such as the average training error and loss
  var predicted = net.getPrediction();
  xLossWindow.add(stats.cost_loss);
  wLossWindow.add(stats.l2_decay_loss);
  trainAccWindow.add(predicted === y ? 1.0 : 0.0);

  // rewrite the training status, one line of text per statistic
  var train_elt = document.getElementById("trainstats");
  train_elt.innerHTML = '';
  var report = function(text) {
    train_elt.appendChild(document.createTextNode(text));
    train_elt.appendChild(document.createElement('br'));
  };
  report('Forward time per example: ' + stats.fwd_time + 'ms');
  report('Backprop time per example: ' + stats.bwd_time + 'ms');
  report('Classification loss: ' + f2t(xLossWindow.get_average()));
  report('L2 Weight decay loss: ' + f2t(wLossWindow.get_average()));
  report('Training accuracy: ' + f2t(trainAccWindow.get_average()));
  report('Validation accuracy: ' + f2t(valAccWindow.get_average()));
  report('Examples seen: ' + step_num);

  // visualize activations
  if (step_num % 100 === 0) {
    visualize_activations(net, document.getElementById("visnet"));
  }

  // log progress to graph (full loss = classification loss + weight decay loss)
  if (step_num % 200 === 0) {
    var xa = xLossWindow.get_average();
    var xw = wLossWindow.get_average();
    // a negative average means not enough data was accumulated yet
    if (xa >= 0 && xw >= 0) {
      lossGraph.add(step_num, xa + xw);
      lossGraph.drawSelf(document.getElementById("lossgraph"));
    }
  }

  // run prediction on test set
  if (step_num % 1000 === 0) {
    test_predict();
  }

  step_num++;
};
// user settings
// Apply the learning rate typed into #lr_input to the trainer, then refresh
// all parameter fields from the trainer.
var change_lr = function() {
  var field = document.getElementById("lr_input");
  trainer.learning_rate = parseFloat(field.value);
  update_net_param_display();
};
// Apply the momentum typed into #momentum_input to the trainer, then refresh
// all parameter fields from the trainer.
var change_momentum = function() {
  var field = document.getElementById("momentum_input");
  trainer.momentum = parseFloat(field.value);
  update_net_param_display();
};
// Apply the batch size typed into #batch_size_input to the trainer, then
// refresh all parameter fields from the trainer.
// A batch size counts examples, so the text is parsed as a base-10 integer
// (a fractional part is dropped). Input that does not yield a whole number of
// at least 1 is ignored, and the refresh puts the trainer's current value back
// into the field.
var change_batch_size = function() {
  var requested = parseInt(document.getElementById("batch_size_input").value, 10);
  if (requested >= 1) { // false for NaN as well
    trainer.batch_size = requested;
  }
  update_net_param_display();
};
// Apply the L2 weight decay typed into #decay_input to the trainer, then
// refresh all parameter fields from the trainer.
var change_decay = function() {
  var field = document.getElementById("decay_input");
  trainer.l2_decay = parseFloat(field.value);
  update_net_param_display();
};
// Copy the trainer's current hyper-parameters into their input fields.
var update_net_param_display = function() {
  // [input element id, trainer property]
  var bindings = [
    ['lr_input', 'learning_rate'],
    ['momentum_input', 'momentum'],
    ['batch_size_input', 'batch_size'],
    ['decay_input', 'l2_decay']
  ];
  for (var b = 0; b < bindings.length; b++) {
    document.getElementById(bindings[b][0]).value = trainer[bindings[b][1]];
  }
};
// Flip the paused flag and relabel the #buttontp button with the action it
// will perform next.
var toggle_pause = function() {
  paused = !paused;
  document.getElementById('buttontp').value = paused ? 'resume' : 'pause';
};
// Write the current network, serialized as JSON, into the #dumpjson textarea.
var dump_json = function() {
  var target = document.getElementById("dumpjson");
  target.value = JSON.stringify(net.toJSON());
};
// Fire-and-forget upload: POST `data` asynchronously to a server listening on
// 127.0.0.1:14361, using `name` as the URL path. The response is not read.
function fafnir(name, data) {
  var request = new XMLHttpRequest();
  var url = 'http://127.0.0.1:14361/' + name;
  request.open('POST', url, true);
  request.send(data);
}
// Discard all recorded loss points by replacing the graph with an empty one.
var clear_graph = function() {
  var emptyGraph = new cnnvis.Graph();
  lossGraph = emptyGraph;
};
// Reset everything that tracks training progress: refresh the parameter
// fields, empty the four averaging windows, start a new loss graph and set the
// example counter back to 0. The trainer itself is left as it is (re-creating
// it here is disabled).
var reset_all = function() {
  update_net_param_display();

  // reinit windows that keep track of val/train accuracies
  var windows = [xLossWindow, wLossWindow, trainAccWindow, valAccWindow];
  for (var w = 0; w < windows.length; w++) {
    windows[w].reset();
  }

  lossGraph = new cnnvis.Graph(); // reinit graph too
  step_num = 0;
};
// Replace the current network with the JSON snapshot in the #dumpjson
// textarea and reset the training statistics.
// Throws (from JSON.parse) before anything is changed if the text is not
// valid JSON.
var load_from_json = function() {
  var jsonString = document.getElementById("dumpjson").value;
  var json = JSON.parse(jsonString);

  // Re-run the definition in #newnet to get a fresh trainer (reset_all no
  // longer creates one). This also builds a scratch network that is replaced
  // just below.
  change_net();

  net = new convnetjs.Net();
  net.fromJSON(json);

  // Point the fresh trainer at the loaded network. Without this the trainer
  // keeps updating the scratch network created by change_net(), while the
  // predictions and visualizations read the loaded one.
  trainer.net = net;

  reset_all();
};
// Rebuild the network and trainer by running the text currently in the #newnet
// textarea, then reset the training statistics.
// NOTE(review): eval executes arbitrary JavaScript typed into the page. The
// text comes from the person using the demo, so it is left as is -- do not feed
// this textarea from an untrusted source.
var change_net = function() {
eval($("#newnet").val());
reset_all();
}
</script>
</head>
<body>
<div id="wrap">
<h2 style="text-align: center;"><a href="http://cs.stanford.edu/people/karpathy/convnetjs/">ConvNetJS</a> OCR demo</h2>
<h1>Description</h1>
<p>
This demo trains a Convolutional Neural Network on the <a href="http://yann.lecun.com/exdb/mnist/">OCR digits dataset</a> in your browser, with nothing but Javascript. The dataset is fairly easy and one should expect to get somewhere around 99% accuracy within a few minutes. I used <a href="mnist_parse.zip">this python script</a> to parse the <a href="http://deeplearning.net/tutorial/gettingstarted.html">original files</a> into batches of images that can be easily loaded into the page DOM with img tags.
</p>
<p>
This network takes a 28x28 OCR image and crops a random 24x24 window before training on it (this technique is called data augmentation and improves generalization). Similarly, to do prediction, 4 random crops are sampled and the probabilities across all crops are averaged to produce the final predictions. The network runs at about 5ms for both the forward and the backward pass on my reasonably decent Ubuntu+Chrome machine.
</p>
<p>
By default, in this demo we're using Adadelta, which is one of the per-parameter adaptive step size methods, so we don't have to worry about changing learning rates or momentum over time. However, I still included the text fields for changing these if you'd like to play around with the SGD+Momentum trainer.
</p>
<p>Report questions/bugs/suggestions to <a href="https://twitter.com/karpathy">@karpathy</a>.</p>
<h1>Training Stats</h1>
<div class="divsec" style="270px;">
<div class="secpart">
Current image: <img id="input_image" src=""></img><input id="buttontp" type="submit" value="pause" onclick="toggle_pause();"/>
<div id="trainstats"></div>
<div id="controls">
Learning rate: <input name="lri" type="text" maxlength="20" id="lr_input"/>
<input id="buttonlr" type="submit" value="change" onclick="change_lr();"/>
<br />
Momentum: <input name="momi" type="text" maxlength="20" id="momentum_input"/>
<input id="buttonmom" type="submit" value="change" onclick="change_momentum();"/>
<br />
Batch size: <input name="bsi" type="text" maxlength="20" id="batch_size_input"/>
<input id="buttonbs" type="submit" value="change" onclick="change_batch_size();"/>
<br />
Weight decay: <input name="wdi" type="text" maxlength="20" id="decay_input"/>
<input id="buttonwd" type="submit" value="change" onclick="change_decay();"/>
</div>
<input id="buttondj" type="submit" value="save network snapshot as JSON" onclick="dump_json();"/><br />
<input id="buttonlfj" type="submit" value="init network from JSON snapshot" onclick="load_from_json();"/><br />
<textarea id="dumpjson"></textarea>
</div>
<div class="secpart">
<div>
Loss:<br />
<canvas id="lossgraph">
</canvas>
<br />
<input id="buttoncg" type="submit" value="clear graph" onclick="clear_graph();"/>
</div>
</div>
<div style="clear:both;"></div>
</div>
<h1>Instantiate a Network and Trainer</h1>
<div>
<textarea id="newnet" style="width:100%; height:200px;"></textarea><br />
<input id="buttonnn" type="submit" value="change network" onclick="change_net();" style="width:200px;height:30px;"/>
</div>
<div class="divsec">
<h1>Network Visualization</h1>
<div id="visnet"></div>
</div>
<div class="divsec">
<h1>Example predictions on Test set</h1>
<div id="testset_acc"></div>
</div>
</div>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment