Skip to content

Instantly share code, notes, and snippets.

@Nithanaroy
Created February 21, 2020 07:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Nithanaroy/fc3aebab154fec76d440076bcf88a585 to your computer and use it in GitHub Desktop.
Save Nithanaroy/fc3aebab154fec76d440076bcf88a585 to your computer and use it in GitHub Desktop.
Visualize Loss and Accuracy curves using Comlink
// In the main thread
class VisionModel {
  /**
   * Creates a main-thread wrapper around the Vision Model that resides in a Web Worker.
   * Unlike the worker, this side has access to the DOM and can render visualizations.
   * @param {HTMLElement} tensorboardDiv Container element in which tfvis draws the live training curves.
   */
  constructor(tensorboardDiv) {
    this.tensorboardDiv = tensorboardDiv;
  }

  /**
   * Spawns the worker script and constructs the remote VisionModelWorker inside it.
   * The resulting Comlink proxy is stored on `this.visionModelWorker`.
   */
  async init() {
    // Named differently from the worker-side class to avoid shadowing it.
    const RemoteVisionModelWorker = Comlink.wrap(new Worker("model-worker.js"));
    this.visionModelWorker = await new RemoteVisionModelWorker();
  }

  /**
   * Trains the model in the worker while plotting loss/accuracy curves on the main thread.
   * Calls init() lazily on first use.
   * @param {Object} [options]
   * @param {number} [options.batchSize=1024] Training batch size.
   * @param {number} [options.epochs=1] Number of training epochs.
   * @param {boolean} [options.trainExisting=true] Forwarded to the worker's run(); when false
   *   the worker re-creates its model before training instead of continuing with the existing one.
   * @returns {Promise<*>} Whatever the worker's run() resolves to.
   */
  async run({ batchSize = 1024, epochs = 1, trainExisting = true } = {}) {
    // NOTE(review): two concurrent first calls can both reach init() and spawn separate workers.
    if (!this.visionModelWorker) { await this.init(); }
    const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
    // The tfvis callbacks draw into the DOM, so they live on this thread; the worker invokes
    // them remotely through Comlink proxies.
    const vizCallbacks = tfvis.show.fitCallbacks(this.tensorboardDiv, metrics);
    // Arguments are passed positionally (see https://github.com/GoogleChromeLabs/comlink/issues/420);
    // functions cannot be structured-cloned, hence the Comlink.proxy wrappers.
    return this.visionModelWorker.run(
      batchSize,
      epochs,
      trainExisting,
      Comlink.proxy(vizCallbacks.onBatchEnd),
      Comlink.proxy(vizCallbacks.onEpochEnd)
    );
  }
}
// In the worker thread
class VisionModelWorker {
  // ... getData(), create() methods like before

  /**
   * Fits `this.model` on `this.dataBunch.trainX` / `this.dataBunch.trainY`, reporting progress
   * through `vizCallbacks`.
   * Note: the parameter order here is (epochs, batchSize), the reverse of run(batchSize, epochs).
   * @param {number} epochs
   * @param {number} batchSize
   * @param {{onBatchEnd: ?Function, onEpochEnd: ?Function}} vizCallbacks Passed to model.fit as `callbacks`.
   */
  async train(epochs, batchSize, vizCallbacks) {
    const historyObj = await this.model.fit(this.dataBunch.trainX, this.dataBunch.trainY, {
      // ... params like before
      callbacks: vizCallbacks
    });
    // ...
  }

  /**
   * Worker-side entry point, called positionally from the main thread via Comlink.
   * (Re)creates the model when none exists or when `trainExisting` is false, loads the data,
   * then trains.
   * @param {number} [batchSize=1024]
   * @param {number} [epochs=1]
   * @param {boolean} [trainExisting=true] When false, always call create() before training.
   * @param {?Function} [onBatchEndCb=null] Batch-end callback, typically a Comlink remote proxy.
   * @param {?Function} [onEpochEndCb=null] Epoch-end callback, typically a Comlink remote proxy.
   * @returns {Promise<*>} Whatever train() resolves to.
   */
  async run(batchSize = 1024, epochs = 1, trainExisting = true, onBatchEndCb = null, onEpochEndCb = null) {
    try {
      if (!this.model || !trainExisting) {
        await this.create();
      }
      await this.getData();
      const vizCallbacks = {
        onBatchEnd: onBatchEndCb,
        onEpochEnd: onEpochEndCb
      };
      // `return await` (not a bare `return`) so the finally block runs only after training settles.
      return await this.train(epochs, batchSize, vizCallbacks);
    } finally {
      // Each Comlink remote proxy keeps a MessagePort open until released. Plain functions
      // and null callbacks have no release hook and are skipped.
      for (const cb of [onBatchEndCb, onEpochEndCb]) {
        if (cb && typeof cb[Comlink.releaseProxy] === 'function') {
          cb[Comlink.releaseProxy]();
        }
      }
    }
  }
}
// Register the class with Comlink so a main-thread `Comlink.wrap(...)` proxy of this worker
// can construct it with `new` and call its methods (e.g. run) remotely.
Comlink.expose(VisionModelWorker);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment