Skip to content

Instantly share code, notes, and snippets.

@lucydjo
Last active October 13, 2019 11:30
Show Gist options
  • Save lucydjo/8244dc4d733d5a053cb92b4f3bc63773 to your computer and use it in GitHub Desktop.
Save lucydjo/8244dc4d733d5a053cb92b4f3bc63773 to your computer and use it in GitHub Desktop.
Digit recognition : Training Step
const fs = require('fs');
const db = require('node-localdb');
const csv = require('csvtojson');
const colors = require('colors');
const log = require('fancy-log');
const assert = require('assert');
const brain = require('brain.js');
const IterateObject = require("iterate-object");
// CSV file the rows to evaluate are read from.
const csvFilePath = 'digit_csv/train_slim.csv';

// Rebuild the network from the JSON snapshot stored on disk.
const contents = fs.readFileSync("training_data_output_2.json");
const jsonContent = JSON.parse(contents);
const net = new brain.NeuralNetwork();
net.fromJSON(jsonContent);

// Shared accumulators filled by the loading / formatting steps below:
// raw CSV rows, network-ready samples, and the label parsed from each row.
const train_data = [];
const train_data_formated = [];
const labels = [];
// Load every CSV row into `train_data`, then start the formatting step.
// The hand-off happens inside the promise callback so the rows are in place
// before generateTrainingObj() runs, without relying on how the parser orders
// its 'done' event relative to promise resolution.
csv()
  .fromFile(csvFilePath)
  .then((jsonObj) => {
    const rowCount = jsonObj.length;
    jsonObj.forEach((row, index) => {
      train_data.push(row);
      // Progress trace: index of the row just stored and total row count.
      console.log(index, rowCount);
    });
    log.info('Loading training data complete !');
    generateTrainingObj();
  })
  .catch((error) => {
    // Surface read/parse/processing failures instead of leaving an
    // unhandled rejection, and signal the failure through the exit code.
    log.error('Failed to load or process the training data:', error);
    process.exitCode = 1;
  });
/**
 * Convert each raw row of `train_data` into a network input object.
 *
 * For every row, all columns except `label` are parsed as floats and divided
 * by 255 (presumably scaling 0-255 intensities into 0-1), then stored under
 * the keys `pixel0`, `pixel1`, ... in column order. The parsed label is
 * appended to `labels` and `{ input }` to `train_data_formated`, so both
 * arrays stay index-aligned. When at least one row was processed,
 * startTesting() is invoked once at the end.
 */
function generateTrainingObj() {
  log.info('Construct training obj..');
  for (const row of train_data) {
    const input_obj = {};
    let pixelIndex = 0;
    for (const name of Object.keys(row)) {
      if (name === 'label') {
        continue;
      }
      input_obj[`pixel${pixelIndex}`] = parseFloat(row[name]) / 255;
      pixelIndex += 1;
    }
    // Explicit radix: labels are decimal digits.
    labels.push(Number.parseInt(row.label, 10));
    train_data_formated.push({ input: input_obj });
  }
  // Same trigger as before (only when there is data), but evaluated once
  // after the loop instead of on every iteration.
  if (train_data.length > 0) {
    startTesting();
  }
}
/**
 * Feed every prepared sample to the restored network and print the network
 * output next to the label stored at the same position in `labels`.
 */
function startTesting() {
  log.info('Start testing !');
  train_data_formated.forEach((sample, position) => {
    const prediction = net.run(sample.input);
    // Network output first, then the label recorded for this sample.
    console.log(prediction, labels[position]);
  });
}
const fs = require('fs');
const db = require('node-localdb');
const csv = require('csvtojson');
const colors = require('colors');
const log = require('fancy-log');
const assert = require('assert');
const brain = require('brain.js');
const IterateObject = require("iterate-object");
/* Training CSV; it contains pixel colour values in the range 0 - 255. */
const csvFilePath = 'digit_csv/train.csv';

/* Raw CSV rows, and the { input, output } samples derived from them. */
const train_data = [];
const train_data_formated = [];
// Upper bound on the number of CSV rows kept for training; later rows are
// ignored.
const MAX_TRAINING_ROWS = 2000;

// Load the first MAX_TRAINING_ROWS rows into `train_data`, then start the
// formatting step. The hand-off happens inside the promise callback so the
// rows are in place before generateTrainingObj() runs, without relying on how
// the parser orders its 'done' event relative to promise resolution.
csv()
  .fromFile(csvFilePath)
  .then((jsonObj) => {
    const rowCount = jsonObj.length;
    jsonObj.slice(0, MAX_TRAINING_ROWS).forEach((row, index) => {
      train_data.push(row);
      // Progress trace: index of the row just stored and total row count.
      console.log(index, rowCount);
    });
    log.info('Loading training data complete !');
    generateTrainingObj();
  })
  .catch((error) => {
    // Surface read/parse/processing failures instead of leaving an
    // unhandled rejection, and signal the failure through the exit code.
    log.error('Failed to load or process the training data:', error);
    process.exitCode = 1;
  });
/**
 * Convert each raw row of `train_data` into an `{ input, output }` sample and
 * append it to `train_data_formated`.
 *
 * - input:  every column except `label`, parsed as a float and divided by
 *           255, stored under `pixel0`, `pixel1`, ... in column order.
 * - output: `{ N<label>: 1 }` for labels 0-9 (e.g. label 3 -> `{ N3: 1 }`);
 *           `null` when the label is not an integer in that range, exactly as
 *           before.
 *
 * Progress (rows done, total rows) is printed after each row. When at least
 * one row was processed, startTraining() is invoked once at the end.
 */
function generateTrainingObj() {
  log.info('Construct training obj..');
  train_data.forEach((row, index) => {
    // Explicit radix: labels are decimal digits.
    const label = Number.parseInt(row.label, 10);

    const input_obj = {};
    let pixelIndex = 0;
    for (const name of Object.keys(row)) {
      if (name === 'label') {
        continue;
      }
      input_obj[`pixel${pixelIndex}`] = parseFloat(row[name]) / 255;
      pixelIndex += 1;
    }

    // One output key per digit class, built from the label instead of ten
    // separate comparisons.
    const isDigit = Number.isInteger(label) && label >= 0 && label <= 9;
    const output = isDigit ? { [`N${label}`]: 1 } : null;

    train_data_formated.push({ input: input_obj, output: output });
    console.log(index + 1, train_data.length);
  });
  // Same trigger as before (only when there is data), but evaluated once
  // after the loop instead of on every iteration.
  if (train_data.length > 0) {
    startTraining();
  }
}
/* Not used: nothing in the visible script reads this object. */
const config = {
  // array of ints for the sizes of the hidden layers in the network
  hiddenLayers: [300, 200, 100, 50],
  // supported activation types: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
  activation: 'sigmoid',
};
/**
 * Train a fresh GPU network on `train_data_formated` and save the resulting
 * model as JSON to `training_data_output_2.json`.
 *
 * Training runs for 20 iterations and reports progress through `log.info`
 * on every iteration. A failed write is reported through `log.error`
 * instead of being announced as a success.
 */
function startTraining() {
  log.info('Start training !');
  const net = new brain.NeuralNetworkGPU();
  // train() is a blocking call (brain.js provides trainAsync() as the
  // asynchronous variant), so the network is fully trained when it returns
  // and the snapshot can be taken immediately — no fixed delay is needed.
  net.train(train_data_formated, {
    logPeriod: 1,
    iterations: 20,
    log: (detail) => log.info(detail),
  });

  log.info('SAVE TRAINING DATA');
  const jsonOutput = JSON.stringify(net.toJSON());
  fs.writeFile('training_data_output_2.json', jsonOutput, 'utf8', (error) => {
    if (error) {
      // Do not claim success when the model could not be written.
      log.error('SAVING FAILED', error);
      return;
    }
    log.info('SAVING DONE');
  });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment