Last active
October 13, 2019 11:30
-
-
Save lucydjo/8244dc4d733d5a053cb92b4f3bc63773 to your computer and use it in GitHub Desktop.
Digit recognition: Training Step
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const fs = require('fs'); | |
const db = require('node-localdb'); | |
const csv = require('csvtojson'); | |
const colors = require('colors'); | |
const log = require('fancy-log'); | |
const assert = require('assert'); | |
const brain = require('brain.js'); | |
const IterateObject = require("iterate-object"); | |
/* CSV used for testing: a `label` column plus pixel columns (raw values 0 - 255). */
const csvFilePath = 'digit_csv/train_slim.csv';
// Raw rows parsed from the CSV (filled by the loader, never reassigned).
const train_data = [];
// Network inputs ({ input }) built from `train_data`.
const train_data_formated = [];
// Expected digit for each entry of `train_data_formated` (same index).
const labels = [];
// Load the network that the training script serialised with net.toJSON().
// NOTE(review): the model is trained with brain.NeuralNetworkGPU but restored into a
// CPU brain.NeuralNetwork here; confirm both share the same JSON format.
const contents = fs.readFileSync('training_data_output_2.json', 'utf8');
const jsonContent = JSON.parse(contents);
const net = new brain.NeuralNetwork();
net.fromJSON(jsonContent);
/*
 * Parse the CSV, store every row in `train_data`, then build and run the test set.
 * Everything is chained on the promise returned by csvtojson, so the rows are
 * stored before generateTrainingObj() reads them. Parse or processing errors are
 * reported (and reflected in the exit code) instead of being silently ignored.
 */
csv()
  .fromFile(csvFilePath)
  .then((rows) => {
    // Push row by row: spreading a very large array into push() can exceed the call-stack limit.
    for (const row of rows) {
      train_data.push(row);
    }
    log.info('Loading training data complete !', rows.length, 'rows');
    generateTrainingObj();
  })
  .catch((error) => {
    log.error('Testing pipeline failed:', error);
    process.exitCode = 1;
  });
/**
 * Converts one parsed CSV row into a network input object.
 * Every column except `label` becomes `pixel<i>`, where i counts the non-label
 * columns in key order. Each value is scaled from 0-255 down to 0-1.
 *
 * @param {Object<string, string>} row - CSV row as produced by csvtojson.
 * @returns {Object<string, number>} e.g. { pixel0: 0, pixel1: 1, ... }
 */
function rowToPixelInput(row) {
  return Object.keys(row)
    .filter((name) => name !== 'label')
    .reduce((input, name, index) => {
      input[`pixel${index}`] = parseFloat(row[name]) / 255;
      return input;
    }, {});
}

/**
 * Builds `train_data_formated` (network inputs) and `labels` (expected digits,
 * index-aligned) from `train_data`, then runs the test pass.
 */
function generateTrainingObj() {
  log.info('Construct training obj..');
  for (const row of train_data) {
    labels.push(Number.parseInt(row.label, 10));
    train_data_formated.push({ input: rowToPixelInput(row) });
  }
  // Run the test pass exactly once, after every row has been converted.
  startTesting();
}
/**
 * Feeds every prepared sample through the loaded network and prints the raw
 * network output next to the label expected for that sample.
 */
function startTesting() {
  log.info('Start testing !');
  train_data_formated.forEach((sample, position) => {
    const prediction = net.run(sample.input);
    // Raw output object followed by the expected digit, for manual comparison.
    console.log(prediction, labels[position]);
  });
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const fs = require('fs'); | |
const db = require('node-localdb'); | |
const csv = require('csvtojson'); | |
const colors = require('colors'); | |
const log = require('fancy-log'); | |
const assert = require('assert'); | |
const brain = require('brain.js'); | |
const IterateObject = require("iterate-object"); | |
/* Training CSV: a `label` column plus pixel columns holding colour values 0 - 255. */
const csvFilePath = 'digit_csv/train.csv';
// Raw rows selected for training (filled by the loader, never reassigned).
const train_data = [];
// Training samples ({ input, output }) built from `train_data`.
const train_data_formated = [];
// Only the first MAX_TRAINING_ROWS rows of the CSV are used for training.
const MAX_TRAINING_ROWS = 2000;

/*
 * Parse the CSV, keep the first MAX_TRAINING_ROWS rows in `train_data`, then build
 * the training set. Everything is chained on the promise returned by csvtojson,
 * so the rows are stored before generateTrainingObj() reads them. Parse or
 * processing errors are reported (and reflected in the exit code) instead of
 * being silently ignored.
 */
csv()
  .fromFile(csvFilePath)
  .then((rows) => {
    for (const row of rows.slice(0, MAX_TRAINING_ROWS)) {
      train_data.push(row);
    }
    log.info('Loading training data complete !', `${train_data.length}/${rows.length} rows kept`);
    generateTrainingObj();
  })
  .catch((error) => {
    log.error('Training pipeline failed:', error);
    process.exitCode = 1;
  });
/**
 * Converts one parsed CSV row into a network input object.
 * Every column except `label` becomes `pixel<i>`, where i counts the non-label
 * columns in key order. Each value is scaled from 0-255 down to 0-1.
 *
 * @param {Object<string, string>} row - CSV row as produced by csvtojson.
 * @returns {Object<string, number>} e.g. { pixel0: 0, pixel1: 1, ... }
 */
function rowToPixelInput(row) {
  return Object.keys(row)
    .filter((name) => name !== 'label')
    .reduce((input, name, index) => {
      input[`pixel${index}`] = parseFloat(row[name]) / 255;
      return input;
    }, {});
}

/**
 * Maps a digit label to the output object the network is trained against.
 * Only the key for the matching digit (N0 ... N9) is set, to 1.
 *
 * @param {number} label - Integer parsed from the CSV `label` column.
 * @returns {?Object<string, number>} `{ N<label>: 1 }` for 0-9, otherwise null.
 */
function labelToOutput(label) {
  if (!Number.isInteger(label) || label < 0 || label > 9) {
    return null;
  }
  return { [`N${label}`]: 1 };
}

/**
 * Builds `train_data_formated` ({ input, output } samples) from `train_data`,
 * then starts training once all rows are converted.
 * Rows whose label is not a digit 0-9 are skipped with a warning. A sample with
 * a null output presumably cannot be trained on by brain.js.
 */
function generateTrainingObj() {
  log.info('Construct training obj..');
  for (const row of train_data) {
    const output = labelToOutput(Number.parseInt(row.label, 10));
    if (output === null) {
      log.warn('Skipping row with invalid label:', row.label);
      continue;
    }
    train_data_formated.push({ input: rowToPixelInput(row), output });
  }
  log.info('Training samples prepared:', train_data_formated.length, '/', train_data.length);

  if (train_data_formated.length === 0) {
    log.warn('No usable training samples; training skipped.');
    return;
  }
  startTraining();
}
/*
 * Not used: an alternative network configuration kept for reference.
 * startTraining() creates its network without passing this object.
 */
const config = {
  // Sizes of the hidden layers, in order.
  hiddenLayers: [300, 200, 100, 50],
  // Supported activation types: 'sigmoid', 'relu', 'leaky-relu', 'tanh'.
  activation: 'sigmoid',
};
/**
 * Trains a GPU-backed brain.js network on `train_data_formated` and writes the
 * serialised network (net.toJSON()) to `training_data_output_2.json`.
 *
 * net.train() runs synchronously, so the model is saved right after it returns.
 * A timer-based delay is unnecessary: a timer could not fire before a blocking
 * train() call finished anyway. Write failures are logged and set a non-zero
 * exit code instead of being reported as success.
 */
function startTraining() {
  log.info('Start training !');
  const net = new brain.NeuralNetworkGPU();
  net.train(train_data_formated, {
    logPeriod: 1,
    iterations: 20,
    log: (detail) => log.info(detail),
  });

  log.info('SAVE TRAINING DATA');
  const jsonOutput = JSON.stringify(net.toJSON());
  fs.writeFile('training_data_output_2.json', jsonOutput, 'utf8', (error) => {
    if (error) {
      log.error('SAVING FAILED', error);
      process.exitCode = 1;
      return;
    }
    log.info('SAVING DONE');
  });
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment