Created October 29, 2020 16:59
-
-
Save N8python/5e447e5e6581404e1bfe8fac19df3c0a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const tf = require("@tensorflow/tfjs-node"); | |
const fs = require("fs"); | |
const { encode } = require("punycode"); | |
/**
 * One-hot encode a class index.
 *
 * @param {number} num - Class index; must be an integer in [0, numClasses).
 * @param {number} [numClasses=10] - Length of the output vector (10 digit classes by default).
 * @returns {number[]} Array of length `numClasses` with 1 at index `num` and 0 elsewhere.
 * @throws {RangeError} If `num` is not an integer in [0, numClasses).
 */
function oneHotEncode(num, numClasses = 10) {
  // Without this guard an out-of-range label silently corrupts the vector:
  // num = 12 would grow the array to length 13, num = -1 would set a
  // non-index property and leave the vector all zeros.
  // NOTE(review): assumes labels in the JSON are numbers, not numeric strings — confirm.
  if (!Number.isInteger(num) || num < 0 || num >= numClasses) {
    throw new RangeError(`label ${num} is outside [0, ${numClasses})`);
  }
  const encoded = Array(numClasses).fill(0);
  encoded[num] = 1;
  return encoded;
}
// Each JSON file is presumably an array of { image, label } records. The
// reshape below requires exactly 28 * 28 = 784 values per `image`; the
// division by 255 assumes 8-bit pixel values (0..255) — confirm against the data.
const trainData = JSON.parse(fs.readFileSync("mnist_handwritten_train.json", "utf8"));
const testData = JSON.parse(fs.readFileSync("mnist_handwritten_test.json", "utf8"));

/**
 * Convert parsed samples into an image tensor and a one-hot label tensor.
 *
 * @param {Array<{image: number[], label: number}>} samples - Parsed records.
 * @returns {{images: tf.Tensor, labels: tf.Tensor}} `images` has shape
 *   [samples.length, 28, 28, 1] with pixels scaled by 1/255; `labels` holds
 *   one row per sample produced by `oneHotEncode`.
 */
function toTensors(samples) {
  // tidy() disposes the flat [N, 784] tensor created before reshape; without
  // it that full-size copy of the dataset stays allocated for the whole run.
  return tf.tidy(() => {
    const images = tf
      .tensor(samples.map((sample) => sample.image.map((x) => x / 255)))
      .reshape([samples.length, 28, 28, 1]); // add the channel axis conv2d expects
    const labels = tf.tensor(samples.map((sample) => oneHotEncode(sample.label)));
    return { images, labels };
  });
}

const { images: trainTensorData, labels: trainTensorLabels } = toTensors(trainData);
const { images: testTensorData, labels: testTensorLabels } = toTensors(testData);
// Convolutional classifier for 28x28 single-channel images.
// - Two stages, each made of two 3x3 conv layers followed by 2x2 max-pooling
//   (32 filters, then 64).
// - Flatten, then a dropout-regularized 512-unit dense layer.
// - A 10-way softmax output: one unit per class, matching the one-hot labels.
// (An earlier dense-only variant on flat 784-pixel input was replaced by this model.)
const model = tf.sequential({
  layers: [
    // Stage 1: 32 filters.
    tf.layers.conv2d({ inputShape: [28, 28, 1], filters: 32, kernelSize: 3, activation: "relu" }),
    tf.layers.conv2d({ filters: 32, kernelSize: 3, activation: "relu" }),
    tf.layers.maxPool2d({ poolSize: [2, 2] }),
    // Stage 2: 64 filters.
    tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: "relu" }),
    tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: "relu" }),
    tf.layers.maxPool2d({ poolSize: [2, 2] }),
    // Classifier head.
    tf.layers.flatten(),
    tf.layers.dropout({ rate: 0.25 }),
    tf.layers.dense({ units: 512, activation: "relu" }),
    tf.layers.dropout({ rate: 0.5 }),
    tf.layers.dense({ units: 10, activation: "softmax" }),
  ],
});
// Categorical cross-entropy pairs with the softmax output and one-hot labels.
// Adam is used as the optimizer, and accuracy is tracked during training.
model.compile({
  loss: "categoricalCrossentropy",
  optimizer: "adam",
  metrics: ["accuracy"],
});
/**
 * Train the model, save it to ./conv-model, and report training and test accuracy.
 *
 * Saving happens only after fit() resolves. Previously it was fired from an
 * un-awaited IIFE inside onTrainEnd, so a failed save went unhandled and the
 * save could run concurrently with evaluation.
 */
async function trainAndEvaluate() {
  const info = await model.fit(trainTensorData, trainTensorLabels, {
    epochs: 10,
    batchSize: 32,
  });
  await model.save(`file://./conv-model`);

  // history.acc holds one entry per epoch; report the last one, not the whole array.
  const accHistory = info.history.acc;
  console.log('Final accuracy', accHistory[accHistory.length - 1]);

  // categoricalAccuracy(yTrue, yPred) gives 1/0 per sample, so its mean is the
  // test accuracy. tidy() frees the prediction and intermediate tensors.
  const testAccuracy = tf.tidy(() =>
    tf.metrics
      .categoricalAccuracy(testTensorLabels, model.predict(testTensorData))
      .mean()
      .dataSync()[0]
  );
  console.log('Test Accuracy:', testAccuracy);
}

trainAndEvaluate().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.