@NMZivkovic
Created March 24, 2019 13:57
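// Trains `model` for `epochs` epochs and plots the training curves with tfjs-vis.
// Assumes `tfvis` is loaded globally and `getBatch` is defined elsewhere.
async function trainModelFunction(model, data, epochs) {
  // 'acc' and 'val_acc' are only reported if the model was compiled
  // with an accuracy metric.
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: { height: '1000px' }
  };
  // Callbacks that redraw the charts in the visor as training progresses.
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const batchSize = 512;
  // 5500 samples for training; 1000 held out for validation
  // (the third argument presumably selects the test split).
  const [trainX, trainY] = getBatch(data, 5500);
  const [testX, testY] = getBatch(data, 1000, true);

  return model.fit(trainX, trainY, {
    batchSize: batchSize,
    validationData: [testX, testY],
    epochs: epochs,
    shuffle: true,
    callbacks: fitCallbacks
  });
}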
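
To make the batch settings above concrete, here is a minimal sketch of how 5500 training samples split into batches of 512. The helper `batchesPerEpoch` is hypothetical and not part of the gist; it also assumes the final, smaller batch is still processed rather than dropped.

```javascript
// Hypothetical helper, not part of the gist: works out how many batches
// one epoch is split into and how large the final batch is.
function batchesPerEpoch(numSamples, batchSize) {
  const fullBatches = Math.floor(numSamples / batchSize);
  const remainder = numSamples % batchSize;
  return {
    // A leftover partial batch counts as one extra batch.
    batches: fullBatches + (remainder > 0 ? 1 : 0),
    // If the samples divide evenly, the last batch is a full one.
    lastBatchSize: remainder > 0 ? remainder : batchSize
  };
}

// Values taken from trainModelFunction: 5500 samples, batchSize 512.
const plan = batchesPerEpoch(5500, 512);
console.log(plan); // { batches: 11, lastBatchSize: 380 }
```

Under these assumptions each epoch performs 11 weight updates: ten on full batches of 512 and one on the remaining 380 samples.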