@Nithanaroy
Created February 21, 2020 06:30
Train the CNN Model in TFJS
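import * as tf from '@tensorflow/tfjs';

class VisionModelWorker {
  // ... create and getData functions (omitted in this snippet); they are assumed to set
  // this.model and this.dataBunch, and this.trainingHistories is assumed to start as [].

  async train(epochs, batchSize) {
    // Adam optimizer, categorical cross-entropy (expects one-hot labels), accuracy metric.
    // A new Adam optimizer is created on every call, so optimizer state is not carried
    // over between train() calls; the model weights are.
    this.model.compile({ optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

    // Fit on the training split, shuffle it before each epoch,
    // and evaluate on the test split at the end of each epoch.
    const historyObj = await this.model.fit(this.dataBunch.trainX, this.dataBunch.trainY, {
      batchSize: batchSize,
      validationData: [this.dataBunch.testX, this.dataBunch.testY],
      epochs: epochs,
      shuffle: true
    });

    // Keep every run's history and log its per-epoch metrics.
    this.trainingHistories.push(historyObj);
    console.debug(`Training history: ${JSON.stringify(historyObj.history)}`);
    return historyObj;
  }
}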
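Because train() returns the History object from model.fit() and also pushes it onto this.trainingHistories, its per-epoch metrics can be inspected after training. The following is a minimal sketch of two hypothetical plain-JavaScript helpers (bestEpoch and concatHistories are not part of TFJS or of the original gist). It assumes only that historyObj.history maps metric names such as val_loss to arrays holding one value per epoch, which model.fit() records when validationData is supplied.

```javascript
// Hypothetical helper: return the 0-based epoch with the lowest value of `metric`
// in a History-like object ({ history: { metricName: number[] } }).
function bestEpoch(historyObj, metric = 'val_loss') {
  const values = historyObj.history[metric];
  if (!Array.isArray(values) || values.length === 0) {
    throw new Error(`No "${metric}" values in history`);
  }
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] < values[best]) best = i; // keep the first minimum on ties
  }
  return { epoch: best, value: values[best] };
}

// Hypothetical helper: join several train() runs (e.g. the entries of
// trainingHistories) into one per-epoch series for a single metric.
// Runs that did not record the metric contribute nothing.
function concatHistories(histories, metric = 'val_loss') {
  return histories.flatMap((h) => h.history[metric] ?? []);
}
```

For example, bestEpoch(await worker.train(10, 32)) would report the epoch with the lowest validation loss for that run, where worker is a hypothetical VisionModelWorker instance; concatHistories(worker.trainingHistories) gives one continuous curve across repeated train() calls.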