Created October 29, 2020 17:01
-
-
Save N8python/7cc0f3c07d049c28c8321b55befb7fdf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const tf = require("@tensorflow/tfjs-node"); | |
const fs = require("fs"); | |
const mnistToImage = require("./mnist-image.js"); | |
/**
 * One-hot encode a class label.
 *
 * @param {number} num - Class index in [0, numClasses). Numeric strings such as "3"
 *   are accepted as before (they are coerced with Number()).
 * @param {number} [numClasses=10] - Length of the encoded vector (10 MNIST digits by default).
 * @returns {number[]} Array of length `numClasses` with 1 at index `num` and 0 elsewhere.
 * @throws {RangeError} If `num` is not an integer in [0, numClasses). Previously such labels
 *   silently produced a malformed vector (e.g. index 10 grew the array to length 11).
 */
function oneHotEncode(num, numClasses = 10) {
  const index = Number(num);
  if (!Number.isInteger(index) || index < 0 || index >= numClasses) {
    throw new RangeError(`label ${num} is not an integer in [0, ${numClasses})`);
  }
  return Array.from({ length: numClasses }, (_, i) => (i === index ? 1 : 0));
}
// Load the MNIST JSON dumps and convert them to tensors. Each sample is read as
// { image, label }; image pixels are divided by 255 for the network inputs, and
// labels are one-hot encoded via oneHotEncode.
const readJsonFile = (path) => JSON.parse(fs.readFileSync(path, "utf8"));
const scalePixels = ({ image }) => image.map((px) => px / 255);
const labelToOneHot = ({ label }) => oneHotEncode(label);

const trainData = readJsonFile("mnist_handwritten_train.json");
const testData = readJsonFile("mnist_handwritten_test.json");

const trainTensorData = tf.tensor(trainData.map(scalePixels));
const trainTensorLabels = tf.tensor(trainData.map(labelToOneHot));
const testTensorData = tf.tensor(testData.map(scalePixels));
const testTensorLabels = tf.tensor(testData.map(labelToOneHot));
// Symmetric dense autoencoder: 784 -> 196 -> 49 -> 10 -> 2 -> 10 -> 49 -> 196 -> 784.
// Every hidden layer uses tanh; the output layer uses sigmoid.
const hiddenUnits = [196, 49, 10, 2, 10, 49, 196];
const model = tf.sequential({
  layers: [
    ...hiddenUnits.map((units, position) =>
      position === 0
        ? tf.layers.dense({ inputShape: [784], units, activation: "tanh" })
        : tf.layers.dense({ units, activation: "tanh" })
    ),
    tf.layers.dense({ units: 784, activation: "sigmoid" }),
  ],
});
// Reconstruction objective: each of the 784 sigmoid outputs predicts one input
// pixel intensity, so use per-pixel binary cross-entropy. The previous
// categoricalCrossentropy treated the 784 outputs as a single class distribution
// (tfjs rescales non-logit outputs to sum to 1), which only constrains relative
// pixel proportions, not absolute intensities.
model.compile({
  optimizer: "adam",
  loss: "binaryCrossentropy",
  // With a binaryCrossentropy loss, tfjs resolves "accuracy" to binary accuracy
  // (outputs thresholded at 0.5), still recorded under the history key "acc".
  metrics: ["accuracy"],
});
// Number of completed epochs; used as the suffix of the snapshot file names.
let epoch = 0;

/**
 * Write one random test digit and the model's reconstruction of it to
 * outs/true<N>.png and outs/out<N>.png.
 *
 * @param {number} epochNumber - Suffix used in the output file names.
 */
function writeReconstructionSample(epochNumber) {
  const { image } = testData[Math.floor(Math.random() * testData.length)];
  // Feed the model the same /255 scaling it was trained on (see trainTensorData).
  // Previously the raw pixels were passed, far outside the training input range.
  const pixels = image.map((x) => x / 255);
  // tidy() disposes the input and prediction tensors created on every epoch.
  const reconstruction = tf.tidy(() => model.predict(tf.tensor([pixels])).dataSync());
  // The reconstruction is mapped back by *255; the raw test pixels are already on
  // that scale (assuming the JSON stores 0-255 values, as the /255 normalisation
  // implies), so render a copy of them as-is instead of multiplying by 255 again.
  mnistToImage(image.slice()).pack().pipe(fs.createWriteStream(`outs/true${epochNumber}.png`));
  mnistToImage(reconstruction.map((x) => x * 255)).pack().pipe(fs.createWriteStream(`outs/out${epochNumber}.png`));
}

// The snapshot streams fail with ENOENT if the output directory is missing.
fs.mkdirSync("outs", { recursive: true });

// Train the autoencoder to reproduce its own input (input and target are both trainTensorData).
model
  .fit(trainTensorData, trainTensorData, {
    epochs: 50,
    batchSize: 32,
    callbacks: {
      onEpochEnd() {
        epoch++;
        writeReconstructionSample(epoch);
      },
      // tfjs awaits async callbacks, so fit() resolves only after the save
      // completes, and a failed save rejects fit() instead of being dropped.
      async onTrainEnd() {
        await model.save(`file://./encoder-model`);
      },
    },
  })
  .then((info) => {
    const accHistory = info.history.acc;
    console.log('Final accuracy', accHistory[accHistory.length - 1]);
    // Score reconstruction of the held-out images with the compiled loss and
    // metric. The old check compared the argmax of a 784-pixel reconstruction
    // with a 0-9 digit label, which has no meaning for an autoencoder.
    const [testLoss, testAcc] = model.evaluate(testTensorData, testTensorData, { batchSize: 32 });
    console.log('Test loss:', testLoss.dataSync()[0]);
    console.log('Test Accuracy:', testAcc.dataSync()[0]);
  })
  .catch((err) => {
    console.error(err);
    process.exitCode = 1;
  });
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.