Skip to content

Instantly share code, notes, and snippets.

@Nithanaroy
Last active February 21, 2020 06:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Nithanaroy/9cb8e28e39c6c0c307348f4e8a79bf92 to your computer and use it in GitHub Desktop.
Create a CNN model in TensorFlow.js (TFJS)
// Load TensorFlow.js 1.3.2 from the CDN (provides the `tf` global used below),
// then the local data helpers in data.js (provides `Data`). importScripts runs
// the scripts in the order given.
importScripts(
  "https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.3.2/tf.min.js",
  "data.js"
);
/**
 * Builds a small convolutional image classifier with TensorFlow.js and loads
 * the labelled training/test data it is trained on.
 *
 * Relies on the globals `tf` (TensorFlow.js) and `Data` (from data.js), which
 * this script loads via `importScripts`.
 */
class VisionModelWorker {
  /**
   * @param {number} [numClasses=10] - Number of output classes. Sizes both the
   *   final softmax layer built by `create()` and the one-hot label encoding
   *   produced by `getData()`, so the two always agree.
   * @throws {RangeError} If `numClasses` is not a positive integer.
   */
  constructor(numClasses = 10) {
    if (!Number.isInteger(numClasses) || numClasses < 1) {
      throw new RangeError(`numClasses must be a positive integer, got ${numClasses}`);
    }
    this.numClasses = numClasses;
    this.model = null; // holds the tfjs model once create() has run
    this.dataBunch = null; // holds the X and y datasets once getData() has succeeded
  }

  /**
   * Builds a fresh sequential CNN, prints its summary, and stores it on
   * `this.model`. The model takes 28x28 single-channel inputs and ends in a
   * softmax over `this.numClasses` outputs.
   */
  create() {
    const model = tf.sequential();
    model.add(tf.layers.conv2d({ inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu' }));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
    model.add(tf.layers.conv2d({ filters: 16, kernelSize: 3, activation: 'relu' }));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
    model.add(tf.layers.dense({ units: this.numClasses, activation: 'softmax' }));
    model.summary();
    // Publish only the fully built model, so a failure part-way through
    // leaves any previous model in place.
    this.model = model;
  }

  /**
   * Fetches the dataset via `Data#fetchDataAndSetupState()` and replaces its
   * `trainY` / `testY` labels with one-hot encodings of depth `this.numClasses`.
   * The result is cached on `this.dataBunch`.
   *
   * @param {boolean} [forceFetch=false] - Re-fetch even if data is already cached.
   * @returns {Promise<void>}
   * @throws Propagates any error from `fetchDataAndSetupState()` or `tf.oneHot`.
   *   In that case `this.dataBunch` is left unchanged, so a later call retries.
   */
  async getData(forceFetch = false) {
    if (this.dataBunch && !forceFetch) {
      return;
    }
    // Build into a local and publish it only once fully initialized. Assigning
    // this.dataBunch before the await would cache a half-populated dataset on
    // failure, and concurrent callers would see it before labels are encoded.
    const bunch = new Data();
    await bunch.fetchDataAndSetupState(); // TODO: need a tf.tidy() around this
    bunch.trainY = tf.oneHot(bunch.trainY, this.numClasses);
    bunch.testY = tf.oneHot(bunch.testY, this.numClasses);
    // NOTE(review): if the raw label inputs are tensors they are not disposed
    // here; confirm against data.js before adding tf.dispose().
    this.dataBunch = bunch;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment