Skip to content

Instantly share code, notes, and snippets.

@YankeeTube
Last active October 22, 2022 08:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YankeeTube/ee96f60f57b9038ee0b703fc6620e7d9 to your computer and use it in GitHub Desktop.
Save YankeeTube/ee96f60f57b9038ee0b703fc6620e7d9 to your computer and use it in GitHub Desktop.
Very fast NSFWJS using the TensorFlow.js WASM backend inside a Web Worker (index.html followed by worker.js)
<html>
<head></head>
<body>
<div>
<input type="file" id="file-input" />
</div>
</body>
<script>
// Start the classification worker and tell it to load the model right away,
// so that the work is already under way before the user picks a file.
const worker = new Worker('worker.js');
// Listener and postMessage order is interchangeable: replies arrive asynchronously.
worker.addEventListener('message', nsfwResult);
worker.postMessage('init');

// Hook up the file picker once the document has been parsed.
document.addEventListener('DOMContentLoaded', () => {
  const picker = document.querySelector('#file-input');
  picker.addEventListener('change', imageHandler);
});
/**
 * Handle a message posted back by the worker by logging its payload.
 *
 * @param {MessageEvent} event - Worker message; its `data` is logged as-is.
 */
function nsfwResult(event) {
  const payload = event.data;
  console.log(payload);
}
/**
 * `change` handler for the file input: forwards the selected file to the
 * worker for classification.
 *
 * If the selection is empty, nothing is posted. Posting `undefined` would
 * make the worker's createImageBitmap call throw.
 *
 * @param {Event} e - Change event whose target is the file <input>.
 * @returns {Promise<void>}
 */
async function imageHandler(e) {
  const file = e.target.files[0];
  if (!file) {
    return;
  }
  worker.postMessage(file);
}
</script>
</html>
// worker.js — classifies images with NSFWJS on the TensorFlow.js WASM backend.
// Model files: download group1-shard1of1 and model.json from
// https://github.com/infinitered/nsfwjs/tree/master/example/nsfw_demo/public/quant_nsfw_mobilenet
// and serve them from a models/ directory next to this worker; init() falls back to 'models/model.json'.
// Load TFJS core first; the WASM backend script is loaded after it and provides tf.wasm (used below).
// NOTE(review): both CDN URLs are unversioned and resolve to the latest release;
// consider pinning the same version on both so core and backend cannot drift apart.
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js");
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/tf-backend-wasm.js");
// Directory from which the WASM backend fetches its .wasm binaries.
// NOTE(review): verify the binaries are published under wasm-out/ for the version the CDN resolves to.
tf.wasm.setWasmPaths("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/wasm-out/");
// Production mode: TFJS skips its debug-time correctness checks in favour of speed.
tf.enableProdMode()
// Assigned by init() when the 'init' message is handled; undefined before that.
let model;
// Edge length (px) of the square model input; images are resized to SIZE x SIZE.
const SIZE = 224;
// Index in the model's output vector -> class label (looked up by nsfwProcess).
const NSFW_CLASSES = {
0: 'Drawing',
1: 'Hentai',
2: 'Neutral',
3: 'Porn',
4: 'Sexy'
}
/**
 * Turn the model's raw per-class scores into a label -> percentage map.
 *
 * Entries are inserted in descending score order, so iterating the returned
 * object lists the most likely class first. Ties keep their original index
 * order because Array.prototype.sort is stable.
 *
 * @param {ArrayLike<number>} values - One score per class, indexed like NSFW_CLASSES.
 * @param {number} [topK=5] - Maximum number of classes to keep; clamped to
 *   the range [0, values.length], so short inputs no longer crash.
 * @returns {Object<string, number>} Class label -> score * 100, rounded to 2 decimals.
 */
function nsfwProcess(values, topK = 5) {
  const count = Math.max(0, Math.min(topK, values.length));
  const ranked = Array.from(values, (value, index) => ({ value, index }))
    .sort((a, b) => b.value - a.value)
    .slice(0, count);

  const result = {};
  for (const { value, index } of ranked) {
    result[NSFW_CLASSES[index]] = Number.parseFloat((value * 100).toFixed(2));
  }
  return result;
}
/**
 * Classify one decoded image and post the scores back to the page.
 *
 * Processing steps:
 * 1. Draw the bitmap onto an OffscreenCanvas to read its pixels.
 * 2. Convert the pixels to a float tensor scaled to [0, 1].
 * 3. Resize to SIZE x SIZE when needed.
 * 4. Feed the result to `model` as a batch of one.
 *
 * All intermediate tensors are created inside tf.tidy and freed on every call.
 * The prediction tensor is released in `finally`, so repeated calls do not
 * leak backend memory. The bitmap is only read, never closed; the caller
 * owns it.
 *
 * @param {ImageBitmap} bitmap - Decoded image to classify.
 * @returns {Promise<Object<string, number>>} The result object that was posted.
 */
async function detectNSFW(bitmap) {
  const {width: w, height: h} = bitmap;
  const ctx = new OffscreenCanvas(w, h).getContext('2d');
  ctx.drawImage(bitmap, 0, 0, w, h);
  // getImageData already returns an ImageData; no need to wrap its buffer again.
  const imageData = ctx.getImageData(0, 0, w, h);

  // tf.tidy disposes every tensor created in the callback except the returned one.
  const predictions = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imageData);
    const normalized = tf.div(tf.cast(pixels, 'float32'), 255);
    const [height, width] = pixels.shape;
    const resized = height !== SIZE || width !== SIZE
      ? tf.image.resizeBilinear(normalized, [SIZE, SIZE], true)
      : normalized;
    return model.predict(resized.reshape([1, SIZE, SIZE, 3]));
  });

  try {
    const values = await predictions.data();
    const result = nsfwProcess(values);
    console.log(result);
    self.postMessage(result);
    return result;
  } finally {
    predictions.dispose();
  }
}
// Shared promise for the one-time model load and warm-up. It is null until
// first requested and is reset to null on failure so that a later message
// can retry.
let modelReady = null;

/**
 * Start loading the model at most once and return the shared promise.
 * Callers that arrive while loading is in progress wait on the same load.
 *
 * @returns {Promise<void>}
 */
function ensureModelReady() {
  if (modelReady === null) {
    modelReady = loadModel().catch((err) => {
      modelReady = null;
      throw err;
    });
  }
  return modelReady;
}

/**
 * Load and warm up the model:
 * 1. Switch TFJS to the WASM backend.
 * 2. Load the model from the IndexedDB cache, falling back to
 *    models/model.json; a freshly fetched model is cached best-effort.
 * 3. Run one dummy prediction so the first real image is not slowed
 *    by warm-up.
 *
 * @returns {Promise<void>}
 */
async function loadModel() {
  if (!(await tf.setBackend('wasm'))) {
    console.warn('TFJS WASM backend failed to initialise');
  }
  try {
    model = await tf.loadLayersModel('indexeddb://model');
    console.log('Load NSFW Model!');
  } catch {
    // Cache miss (or unreadable cache): fetch the model from the server.
    model = await tf.loadLayersModel('models/model.json');
    try {
      await model.save('indexeddb://model');
      console.log('Save NSFW Model!');
    } catch (saveErr) {
      // Caching is an optimisation only; keep going with the loaded model.
      console.warn('Could not cache NSFW model in IndexedDB:', saveErr);
    }
  }
  // Warm up. Only runs after a successful load, so a load failure is
  // reported as-is instead of being masked by a TypeError on `model`.
  const warmup = tf.tidy(() => model.predict(tf.zeros([1, SIZE, SIZE, 3])));
  await warmup.data();
  warmup.dispose();
}

/**
 * Worker message handler.
 * - The string 'init' loads and warms up the model.
 * - Any other payload (e.g. the File posted by the page) is decoded with
 *   createImageBitmap and classified once the model is ready. An image that
 *   arrives during loading waits for the load instead of hitting an
 *   undefined `model`.
 *
 * Errors are logged rather than left as unhandled promise rejections.
 *
 * @param {MessageEvent} event - Incoming message; only `data` is used.
 * @returns {Promise<void>}
 */
async function init({data}) {
  try {
    if (data === 'init') {
      await ensureModelReady();
      return;
    }
    const bitmap = await createImageBitmap(data);
    try {
      await ensureModelReady();
      await detectNSFW(bitmap);
    } finally {
      // The pixels have been copied to a canvas/tensors; free the decoded image.
      bitmap.close();
    }
  } catch (err) {
    console.error('NSFW worker failed to handle message:', err);
  }
}
addEventListener('message', init)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment