Skip to content

Instantly share code, notes, and snippets.

@kedevked
Created March 8, 2019 20:36
Show Gist options
  • Save kedevked/4ab3c0f3e76c6f585f1cc0d3d8527168 to your computer and use it in GitHub Desktop.
Save kedevked/4ab3c0f3e76c6f585f1cc0d3d8527168 to your computer and use it in GitHub Desktop.
Using MobileNet for MNIST classification
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');
const mnist = require('mnist');
global.fetch = require('node-fetch')
const NUM_CLASSES = 10;
/**
 * Load the pre-trained MobileNet v1 (0.25, 224) layers model and truncate it
 * into a feature extractor that ends at an intermediate layer.
 *
 * @param {string} [url] - model.json URL of the layers model to load.
 * @param {string} [layerName] - name of the layer whose output becomes the
 *   feature extractor's output (anything except the final classifier).
 * @returns {Promise<Object>} a tf.model sharing the loaded model's inputs and
 *   ending at `layerName`.
 */
async function loadModel(
  url = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json',
  layerName = 'conv_pw_13_relu'
) {
  // tfjs >= 1.0 renamed tf.loadModel to tf.loadLayersModel; support both so
  // the script runs on old and new releases.
  const loadedModel = typeof tf.loadLayersModel === 'function'
    ? await tf.loadLayersModel(url)
    : await tf.loadModel(url);

  // Print every layer name so a different cut-off layer can be picked.
  loadedModel.layers.forEach((layer) => console.log(layer.name));

  // Take an intermediate layer rather than the last (classification) output.
  const layer = loadedModel.getLayer(layerName);
  return tf.model({ inputs: loadedModel.inputs, outputs: layer.output });
}
/* loadModel().then(featureExtractor => {
})*/
/**
 * Build the trainable classification head placed on top of the frozen
 * feature extractor: flatten -> dense(relu) -> dense(softmax).
 *
 * @param {Object} featureExtractor - model whose first output's shape
 *   (without the leading batch dimension) is the head's input shape.
 * @param {number} units - number of units in the hidden dense layer.
 * @returns {Promise<Object>} a tf.sequential model with NUM_CLASSES outputs.
 */
async function buildModel(featureExtractor, units) {
  // inputShape describes one example, so drop the batch dimension.
  const [, ...featureShape] = featureExtractor.outputs[0].shape;

  // Pure reshape (no trainable parameters) so dense layers can consume the
  // extracted feature maps.
  const flattenLayer = tf.layers.flatten({ inputShape: featureShape });

  // Hidden layer of the head being trained.
  const hiddenLayer = tf.layers.dense({
    units,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
    useBias: true,
  });

  // Output layer: one unit per class, softmax gives class probabilities.
  const outputLayer = tf.layers.dense({
    units: NUM_CLASSES,
    kernelInitializer: 'varianceScaling',
    useBias: false,
    activation: 'softmax',
  });

  return tf.sequential({ layers: [flattenLayer, hiddenLayer, outputLayer] });
}
/**
 * Extract MobileNet features for the training images and fit the head on them.
 *
 * Features are computed once up front (the extractor itself is not trained),
 * then `model` is fit on those features against `outputs`.
 *
 * @param {Object} model - compiled head model (built by buildModel).
 * @param {Array<ArrayLike<number>>} inputs - training images, 28*28 values each.
 * @param {Array<Array<number>>} outputs - one label row per image; presumably
 *   one-hot, since the model is compiled with categoricalCrossentropy.
 * @param {Object} featureExtractor - truncated MobileNet (from loadModel).
 * @returns {Promise<Object>} the History object resolved by model.fit.
 */
async function train(model, inputs, outputs, featureExtractor) {
  console.log('training');
  const config = {
    shuffle: true,
    epochs: 1000,
    batchSize: 100,
    callbacks: {
      onEpochEnd: async (_, logs) => { console.log(logs.loss); },
    },
  };
  model.summary();

  // Local bindings: assigning undeclared names would create implicit globals.
  const trainingLabels = tf.tensor2d(outputs);
  // tidy() frees the per-image and stacked 224x224 input tensors created by
  // mobilenetPredict; only the returned feature tensor survives.
  const trainingFeatures = tf.tidy(
    () => featureExtractor.predict(mobilenetPredict(inputs))
  );

  try {
    return await model.fit(trainingFeatures, trainingLabels, config);
  } finally {
    // Release the training tensors whether fit succeeds or throws.
    trainingFeatures.dispose();
    trainingLabels.dispose();
  }
}
/**
 * Convert raw MNIST images into an input batch for MobileNet.
 *
 * Each image is reshaped to 28x28x1, bilinearly resized to 224x224, and its
 * single channel is tiled three times to mimic an RGB image.
 *
 * @param {Array<ArrayLike<number>>} inputs - images with 28*28 values each.
 * @returns {Object} tensor of shape [inputs.length, 224, 224, 3].
 */
function mobilenetPredict(inputs) {
  // tidy() disposes the per-image intermediate tensors; only the stacked
  // batch is returned (and therefore kept alive).
  // NOTE(review): MobileNet weights commonly expect pixels scaled to [-1, 1];
  // confirm the value range of `inputs` and normalize here if needed.
  return tf.tidy(() => {
    const images = inputs.map((pixels) => tf.tensor(pixels)
      .reshape([28, 28, 1])
      .resizeBilinear([224, 224])
      .tile([1, 1, 3]));
    return tf.stack(images, 0);
  });
}
/**
 * Entry point: load the MobileNet feature extractor, build and train the
 * classification head on a small MNIST sample, then print the prediction for
 * every test image next to its expected output.
 */
(async () => {
  try {
    const mobilenet = await loadModel();
    const model = await buildModel(mobilenet, 5);
    model.compile({
      optimizer: tf.train.sgd(0.001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });

    // Small sample from the mnist package (presumably 100 training and 300
    // test items).
    const coolSet = mnist.set(100, 300);
    const inputs = coolSet.training.map((sample) => sample.input);
    const outputs = coolSet.training.map((sample) => sample.output);
    const testInputs = coolSet.test.map((sample) => sample.input);
    const testOutputs = coolSet.test.map((sample) => sample.output);

    // Await training so failures reach the catch below instead of becoming an
    // unhandled promise rejection.
    await train(model, inputs, outputs, mobilenet);

    testInputs.forEach((x, index) => {
      // tidy() releases the input batch, features and prediction tensors once
      // the prediction has been rendered to a string.
      const prediction = tf.tidy(
        () => model.predict(mobilenet.predict(mobilenetPredict([x]))).toString()
      );
      console.log(`Expected Output: ${testOutputs[index]}\nOutput: ${prediction}`);
    });
  } catch (err) {
    console.error(err);
    process.exitCode = 1;
  }
})();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment