@jthomas
Last active September 24, 2023 21:58
Using TensorFlow.js with MobileNet models for image classification on Node.js
{
  "name": "tf-js",
  "version": "1.0.0",
  "main": "script.js",
  "license": "MIT",
  "dependencies": {
    "@tensorflow-models/mobilenet": "^0.2.2",
    "@tensorflow/tfjs": "^0.12.3",
    "@tensorflow/tfjs-node": "^0.1.9",
    "jpeg-js": "^0.3.4"
  }
}

const tf = require('@tensorflow/tfjs')
const mobilenet = require('@tensorflow-models/mobilenet');
// Registers the native TensorFlow backend for Node.js
require('@tensorflow/tfjs-node')
const fs = require('fs');
const jpeg = require('jpeg-js');

// MobileNet expects RGB input, so the alpha channel is dropped
const NUMBER_OF_CHANNELS = 3

// Decode a JPEG file into { width, height, data }, where data is RGBA (4 bytes per pixel)
const readImage = path => {
  const buf = fs.readFileSync(path)
  const pixels = jpeg.decode(buf, true)
  return pixels
}

// Copy the first numChannels channels of every RGBA pixel into a flat array
const imageByteArray = (image, numChannels) => {
  const pixels = image.data
  const numPixels = image.width * image.height;
  const values = new Int32Array(numPixels * numChannels);

  for (let i = 0; i < numPixels; i++) {
    for (let channel = 0; channel < numChannels; ++channel) {
      values[i * numChannels + channel] = pixels[i * 4 + channel];
    }
  }

  return values
}

// Wrap the pixel values in a [height, width, channels] int32 tensor
const imageToInput = (image, numChannels) => {
  const values = imageByteArray(image, numChannels)
  const outShape = [image.height, image.width, numChannels];
  const input = tf.tensor3d(values, outShape, 'int32');

  return input
}

// Point MobileNet at a local model.json (the MODEL argument) instead of the
// default remote URL; overriding `path` relies on this older package version
const loadModel = async path => {
  const mn = new mobilenet.MobileNet(1, 1);
  mn.path = `file://${path}`
  await mn.load()
  return mn
}

const classify = async (model, path) => {
  const image = readImage(path)
  const input = imageToInput(image, NUMBER_OF_CHANNELS)

  const mn_model = await loadModel(model)
  const predictions = await mn_model.classify(input)

  console.log('classification results:', predictions)
  // Free the input tensor's memory once it is no longer needed
  input.dispose()
}

if (process.argv.length !== 4) throw new Error('incorrect arguments: node script.js <MODEL> <IMAGE_FILE>')

// Report failures (e.g. a missing model file) instead of leaving an unhandled rejection
classify(process.argv[2], process.argv[3]).catch(err => {
  console.error(err)
  process.exit(1)
})
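With the second argument set to true, jpeg.decode returns data as a flat RGBA buffer (four bytes per pixel), which is why imageByteArray reads pixels[i * 4 + channel] and keeps only the first three channels. The same stride arithmetic in isolation, checked on a tiny 2×1 image; rgbaToRgb is a hypothetical helper name, not part of jpeg-js or tfjs:

```javascript
// Hypothetical helper: drop the alpha channel from a flat RGBA buffer,
// using the same index arithmetic as imageByteArray above.
const rgbaToRgb = (rgba, width, height, numChannels = 3) => {
  const numPixels = width * height;
  const out = new Int32Array(numPixels * numChannels);
  for (let i = 0; i < numPixels; i++) {
    for (let c = 0; c < numChannels; c++) {
      // Source pixels are 4 bytes apart; destination pixels are numChannels apart.
      out[i * numChannels + c] = rgba[i * 4 + c];
    }
  }
  return out;
};

// A 2x1 image: one opaque red pixel followed by one opaque blue pixel.
const rgba = Uint8Array.from([255, 0, 0, 255, 0, 0, 255, 255]);
console.log(Array.from(rgbaToRgb(rgba, 2, 1))); // [ 255, 0, 0, 0, 0, 255 ]
```

The output has 3 values per pixel instead of 4, matching the [height, width, 3] shape passed to tf.tensor3d.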
@MyIsaak

MyIsaak commented Aug 12, 2019

tfjs-node already has an image decoding function for JPEG (and more) available at tf.node.decodeJpeg, which avoids the need for jpeg-js and greatly simplifies your code.

@tejas77

tejas77 commented Oct 4, 2019

@MyIsaak Thanks, that made the code way smaller 👍

@josephgoksu

I was trying to convert my app from tfjs to tfjs-node, replacing tf.browser.fromPixels with tf.node.decodeImage.
You saved my life. Thanks.
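Unlike decodeJpeg, tf.node.decodeImage accepts BMP, GIF, JPEG and PNG bytes and works out the format itself. The idea behind that kind of dispatch can be sketched with the standard file signatures ("magic bytes") of those formats; sniffImageFormat is a hypothetical helper for illustration, not the tfjs-node implementation:

```javascript
// Hypothetical helper: guess an image format from its leading bytes.
// Signatures: JPEG FF D8 FF, PNG 89 50 4E 47, GIF "GIF8", BMP "BM".
const sniffImageFormat = (bytes) => {
  // A signature matches only if every one of its bytes is present at the start.
  const startsWith = (sig) => sig.every((b, i) => bytes[i] === b);
  if (startsWith([0xff, 0xd8, 0xff])) return 'jpeg';
  if (startsWith([0x89, 0x50, 0x4e, 0x47])) return 'png';
  if (startsWith([0x47, 0x49, 0x46, 0x38])) return 'gif';
  if (startsWith([0x42, 0x4d])) return 'bmp';
  return 'unknown';
};

// With a real file: sniffImageFormat(require('fs').readFileSync(imagePath))
console.log(sniffImageFormat(Buffer.from([0xff, 0xd8, 0xff, 0xe0]))); // jpeg
```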

@o7g8

o7g8 commented Aug 13, 2020

Hello,

Thank you everyone for the inspiration!
With all the input above, I've reduced the original code to the following (the MobileNet model is downloaded automatically):

const tfnode = require('@tensorflow/tfjs-node');
const mobilenet = require('@tensorflow-models/mobilenet');
const fs = require('fs');

const classify = async (imagePath) => {
  // Decode the image bytes into a tensor with 3 (RGB) channels
  const image = fs.readFileSync(imagePath);
  const decodedImage = tfnode.node.decodeImage(image, 3);

  // Downloads the default MobileNet model from its remote URL
  const model = await mobilenet.load();
  const predictions = await model.classify(decodedImage);
  console.log('predictions:', predictions);

  // Release the tensor's memory
  decodedImage.dispose();
}

if (process.argv.length !== 3)
  throw new Error('Usage: node test-tf.js <image-file>')

classify(process.argv[2]).catch(err => {
  console.error(err);
  process.exit(1);
});
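model.classify returns the top predictions as an array of { className, probability } objects (three by default). A minimal sketch of that final step, turning raw scores into probabilities with a softmax and keeping the k most likely classes, assuming a plain array of logits and a matching label list; softmax, topK and the sample labels are hypothetical, and this is not the package's own code:

```javascript
// Hypothetical sketch: softmax over raw scores.
const softmax = (logits) => {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
};

// Pair each probability with its label, sort descending, keep the first k.
const topK = (logits, labels, k = 3) =>
  softmax(logits)
    .map((probability, i) => ({ className: labels[i], probability }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);

// Made-up scores for four made-up classes.
const labels = ['cat', 'dog', 'car', 'tree'];
console.log(topK([2.0, 1.0, 0.1, -1.0], labels, 2).map((p) => p.className)); // [ 'cat', 'dog' ]
```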

@ClementWalter

Are you sure that '@tensorflow-models/mobilenet' includes the MobileNet preprocessing?
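For context on the question: MobileNet v1/v2 models are commonly trained on 224×224 inputs whose 8-bit values are scaled from [0, 255] to [-1, 1] (x / 127.5 - 1). A per-value sketch of that scaling, assuming a flat array of channel values; normalizeForMobileNet is a hypothetical name, and whether a given version of the package applies this internally should be checked in that version's source:

```javascript
// Hypothetical sketch: map 8-bit channel values from [0, 255] to [-1, 1],
// the scaling commonly used for MobileNet v1/v2 inputs.
const normalizeForMobileNet = (values) =>
  Float32Array.from(values, (v) => v / 127.5 - 1);

console.log(Array.from(normalizeForMobileNet([0, 255]))); // [ -1, 1 ]
```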

@KaKi87

KaKi87 commented Sep 24, 2023

@o7g8 Where are you specifying the path to model.json?
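The original script pointed the model at a local model.json by prefixing the MODEL argument with file://. Building that string by hand breaks on relative paths and on characters that need escaping; Node's url.pathToFileURL handles both. A sketch with a hypothetical toModelUrl helper; how the resulting URL is passed to the model loader depends on the mobilenet version and is not shown here:

```javascript
// Hypothetical helper: turn a local model.json path into a file:// URL.
const path = require('path');
const { pathToFileURL } = require('url');

const toModelUrl = (modelPath) => pathToFileURL(path.resolve(modelPath)).href;

// Relative paths are resolved against the current working directory,
// and characters such as spaces are percent-encoded.
console.log(toModelUrl('models/mobilenet v1/model.json'));
```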

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment