Created
March 8, 2019 20:36
-
-
Save kedevked/4ab3c0f3e76c6f585f1cc0d3d8527168 to your computer and use it in GitHub Desktop.
Using MobileNet for MNIST classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Dependencies: TensorFlow.js with the Node backend, the `mnist` dataset
// package, and `node-fetch`. tf.loadModel downloads the model JSON over HTTP
// and expects a global `fetch` to exist in Node, hence the global assignment.
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');
const mnist = require('mnist');
global.fetch = require('node-fetch');

// Number of MNIST digit classes (0-9) predicted by the classifier head.
const NUM_CLASSES = 10;
/**
 * Loads MobileNet v1 (width 0.25, 224x224 input) from the TFJS model hosting
 * and truncates it at the 'conv_pw_13_relu' activation, yielding a fixed
 * feature extractor for transfer learning.
 *
 * NOTE(review): `tf.loadModel` is the pre-1.0 API; on tfjs >= 1.0 this call
 * is `tf.loadLayersModel` — confirm the pinned tfjs version.
 *
 * @returns {Promise<tf.Model>} model mapping [batch, 224, 224, 3] images to
 *   the 'conv_pw_13_relu' activations.
 */
async function loadModel() {
  const loadedModel = await tf.loadModel(
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
  // Cut the graph at an intermediate activation instead of the final
  // 1000-class ImageNet output (debug layer-name dump removed).
  const layer = loadedModel.getLayer('conv_pw_13_relu');
  return tf.model({ inputs: loadedModel.inputs, outputs: layer.output });
}
/* loadModel().then(featureExtractor => { | |
})*/ | |
/**
 * Builds the trainable classifier head that sits on top of the frozen
 * MobileNet feature extractor.
 *
 * @param {tf.Model} featureExtractor - truncated MobileNet; its output shape
 *   determines the input shape of this head.
 * @param {number} units - size of the hidden dense layer.
 * @returns {Promise<tf.Sequential>} uncompiled two-layer dense classifier.
 */
async function buildModel(featureExtractor, units) {
  // Drop the batch dimension: the flatten layer only needs the per-example
  // shape of the extracted features. Flatten itself has no trainable weights;
  // it just reshapes the feature map into a vector for the dense layers.
  const featureShape = featureExtractor.outputs[0].shape.slice(1);
  const flatten = tf.layers.flatten({ inputShape: featureShape });

  // Hidden layer of the head to be trained.
  const hidden = tf.layers.dense({
    units: units,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
    useBias: true
  });

  // Output layer: one unit per class, softmax to produce probabilities.
  const output = tf.layers.dense({
    units: NUM_CLASSES,
    kernelInitializer: 'varianceScaling',
    useBias: false,
    activation: 'softmax'
  });

  return tf.sequential({ layers: [flatten, hidden, output] });
}
/**
 * Trains the classifier head on MNIST images passed through the MobileNet
 * feature extractor.
 *
 * @param {tf.Sequential} model - compiled classifier head.
 * @param {number[][]} inputs - flat 784-length MNIST images.
 * @param {number[][]} outputs - one-hot labels, one per input.
 * @param {tf.Model} featureExtractor - truncated MobileNet.
 * @returns {Promise<tf.History>} the training history from model.fit.
 */
async function train(model, inputs, outputs, featureExtractor) {
  console.log('training');
  const config = {
    shuffle: true,
    epochs: 1000,
    batchSize: 100,
    callbacks: {
      onEpochEnd: async (_, logs) => { console.log(logs.loss); }
    }
  };
  model.summary();
  // `const` declarations — the original assigned implicit globals here,
  // leaking both tensors onto globalThis.
  const trainingLabels = tf.tensor2d(outputs);
  const trainingFeatures = featureExtractor.predict(mobilenetPredict(inputs));
  // Return the history instead of discarding it in an unused local.
  return model.fit(trainingFeatures, trainingLabels, config);
}
/**
 * Converts flat 784-length MNIST images into a batch tensor shaped for
 * MobileNet: reshape to 28x28x1, bilinear-resize to 224x224, and replicate
 * the single gray channel to 3 channels.
 *
 * @param {number[][]} inputs - array of flat MNIST images.
 * @returns {tf.Tensor} tensor of shape [inputs.length, 224, 224, 3].
 */
function mobilenetPredict(inputs) {
  // `const`, not an implicit global; the original also logged
  // `tensors_inputs.shape`, which is always undefined on a plain Array.
  const inputTensors = inputs.map((image) =>
    tf.tensor(image)
      .reshape([28, 28, 1])
      .resizeBilinear([224, 224])
      .tile([1, 1, 3]));
  return tf.stack(inputTensors, 0);
}
// Entry point: load MobileNet, build and compile the classifier head, train
// on a small MNIST split, then print predictions for every test example.
(async () => {
  const mobilenet = await loadModel();
  const model = await buildModel(mobilenet, 5);
  model.compile({
    optimizer: tf.train.sgd(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  // 100 training examples, 300 test examples (mnist package split).
  const dataset = mnist.set(100, 300);
  const inputs = dataset.training.map((example) => example.input);
  const outputs = dataset.training.map((example) => example.output);
  const testInputs = dataset.test.map((example) => example.input);
  const testOutputs = dataset.test.map((example) => example.output);

  // Await directly instead of a floating .then chain.
  await train(model, inputs, outputs, mobilenet);

  testInputs.forEach((x, index) => {
    // Test features must go through the same MobileNet pipeline used at
    // training time.
    const predictedOutput = model.predict(mobilenet.predict(mobilenetPredict([x])));
    console.log(`Expected Output: ${testOutputs[index]}
    Output: ${predictedOutput.toString()}`);
  });
})().catch((err) => {
  // The original async IIFE had no rejection handler, so any failure became
  // an unhandled promise rejection.
  console.error(err);
  process.exitCode = 1;
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment