Skip to content

Instantly share code, notes, and snippets.

@N8python
Created October 29, 2020 17:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save N8python/7cc0f3c07d049c28c8321b55befb7fdf to your computer and use it in GitHub Desktop.
Save N8python/7cc0f3c07d049c28c8321b55befb7fdf to your computer and use it in GitHub Desktop.
const tf = require("@tensorflow/tfjs-node");
const fs = require("fs");
const mnistToImage = require("./mnist-image.js");
/**
 * One-hot encodes a class index.
 *
 * @param {number|string} num - Class index in [0, numClasses). A numeric string
 *   (e.g. "5") is accepted, as indexing an array with it worked before.
 * @param {number} [numClasses=10] - Length of the output vector (10 for MNIST digits).
 * @returns {number[]} `numClasses` zeros with a single 1 at index `num`.
 * @throws {RangeError} If `num` is not an integer in [0, numClasses). Without this
 *   check an out-of-range label silently produced a longer (or all-zero) array,
 *   which only surfaced later as a ragged-tensor error in tf.tensor.
 */
function oneHotEncode(num, numClasses = 10) {
  const index = typeof num === "string" && num.trim() !== "" ? Number(num) : num;
  if (!Number.isInteger(index) || index < 0 || index >= numClasses) {
    throw new RangeError(`oneHotEncode: label ${num} is not an integer in [0, ${numClasses})`);
  }
  const arr = Array(numClasses).fill(0);
  arr[index] = 1;
  return arr;
}
/** Parses one dataset split: a JSON array of `{ image, label }` samples. */
const readSplit = (path) => JSON.parse(fs.readFileSync(path).toString());
/** Divides every raw pixel by 255 (0-255 input -> 0-1 output). */
const scaleImage = (pixels) => pixels.map((x) => x / 255);

const trainData = readSplit("mnist_handwritten_train.json");
const testData = readSplit("mnist_handwritten_test.json");

// Inputs: one row of scaled pixels per sample. Labels: one one-hot row per sample.
const trainTensorData = tf.tensor(trainData.map(({ image }) => scaleImage(image)));
const trainTensorLabels = tf.tensor(trainData.map(({ label }) => oneHotEncode(label)));
const testTensorData = tf.tensor(testData.map(({ image }) => scaleImage(image)));
const testTensorLabels = tf.tensor(testData.map(({ label }) => oneHotEncode(label)));
// Fully connected autoencoder: 784 -> 196 -> 49 -> 10 -> 2 -> 10 -> 49 -> 196 -> 784.
// The 2-unit tanh layer is the bottleneck; the model is fit with its input as
// its own target, so the sigmoid output is a per-pixel reconstruction in (0, 1).
const model = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [784], units: 196, activation: 'tanh' }),
    tf.layers.dense({ units: 49, activation: "tanh" }),
    tf.layers.dense({ units: 10, activation: "tanh" }),
    tf.layers.dense({ units: 2, activation: "tanh" }),
    tf.layers.dense({ units: 10, activation: "tanh" }),
    tf.layers.dense({ units: 49, activation: "tanh" }),
    tf.layers.dense({ units: 196, activation: "tanh" }),
    tf.layers.dense({ units: 784, activation: 'sigmoid' }),
  ],
});

// Binary cross-entropy scores each output pixel as an independent (0, 1) target,
// which matches the sigmoid output and the /255-normalized inputs.
// categoricalCrossentropy (used previously) renormalizes the 784 outputs to sum
// to 1 before the log, so it ignored absolute pixel brightness.
// With this loss, tfjs should report 'accuracy' as per-pixel binary accuracy.
model.compile({
  optimizer: 'adam',
  loss: 'binaryCrossentropy',
  metrics: ['accuracy'],
});
let epoch = 0;

// fs.createWriteStream emits an unhandled 'error' event if the directory is
// missing, so create it up front.
fs.mkdirSync("outs", { recursive: true });

/**
 * Writes one visual sample for an epoch:
 * - `outs/true<epochNumber>.png` is a random test digit.
 * - `outs/out<epochNumber>.png` is the model's reconstruction of that digit.
 *
 * NOTE(review): assumes mnistToImage takes 784 pixel values on a 0-255 scale
 * (both original call sites scaled by 255). Also assumes its result supports
 * `.pack().pipe(...)`. Confirm against mnist-image.js.
 *
 * @param {number} epochNumber - 1-based epoch used in the output file names.
 */
function writeEpochSample(epochNumber) {
  const image = testData[Math.floor(Math.random() * testData.length)].image;
  // Raw test images are 0-255 (the tensors above divide them by 255).
  // The ground truth is therefore written unscaled, while the model, trained
  // on 0-1 inputs, must be given the normalized pixels.
  const normalized = image.map(x => x / 255);
  // tidy disposes the input tensor; the prediction is disposed once read.
  const prediction = tf.tidy(() => model.predict(tf.tensor([normalized])));
  const reconstruction = prediction.dataSync().map(x => x * 255);
  prediction.dispose();
  mnistToImage([...image]).pack().pipe(fs.createWriteStream(`outs/true${epochNumber}.png`));
  mnistToImage(reconstruction).pack().pipe(fs.createWriteStream(`outs/out${epochNumber}.png`));
}

model.fit(trainTensorData, trainTensorData, {
  epochs: 50,
  batchSize: 32,
  callbacks: {
    onEpochEnd() {
      epoch++;
      writeEpochSample(epoch);
    },
    // tfjs awaits callback promises, so a failed save now rejects fit().
    // Previously it became an unhandled rejection from a detached async IIFE.
    async onTrainEnd() {
      await model.save(`file://./encoder-model`);
    },
  },
}).then(info => {
  const { loss, acc } = info.history;
  console.log('Final training loss', loss[loss.length - 1]);
  if (acc && acc.length > 0) {
    console.log('Final accuracy', acc[acc.length - 1]);
  }
  // The model reconstructs its input, so evaluate it the same way on the test set.
  // The old categoricalAccuracy against the 10-class labels compared the argmax
  // of 784 pixels with the argmax of a digit label, which is meaningless.
  const result = model.evaluate(testTensorData, testTensorData, { batchSize: 32 });
  // evaluate returns a single Scalar without metrics, or [loss, ...metrics] with them.
  const scalars = Array.isArray(result) ? result : [result];
  console.log('Test loss:', scalars[0].dataSync()[0]);
  if (scalars.length > 1) {
    console.log('Test accuracy:', scalars[1].dataSync()[0]);
  }
  tf.dispose(scalars);
}).catch(err => {
  console.error('Training failed:', err);
  process.exitCode = 1;
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment