/*
 * Scraped from a GitHub Gist page; the page text below is kept only for provenance.
 *
 * Skip to content
 * Instantly share code, notes, and snippets.
 * What would you like to do?
 */
/**
 * Runs the model on one batch of held-out examples and returns predicted and
 * true class indices.
 *
 * The batch is reshaped to [testDataSize, 28, 28, 1], so each example must
 * contain 28 * 28 = 784 values.
 *
 * @param {object} model - Object exposing `predict(tensor)` (e.g. a tf.LayersModel).
 * @param {object} data - Data source exposing `nextDataBatch(size, flag)`
 *   returning `{xs, labels}` tensors. NOTE(review): the `true` flag presumably
 *   selects the test split — confirm against the data class.
 * @param {number} [testDataSize=500] - Number of examples to draw.
 * @returns {[object, object]} `[preds, labels]`: arg-max class indices of the
 *   model output and of the labels (labels are presumably one-hot — confirm).
 *   The caller owns both tensors and must dispose them.
 */
function predict(model, data, testDataSize = 500) {
  const testData = data.nextDataBatch(testDataSize, true);
  const testxs = testData.xs.reshape([testDataSize, 28, 28, 1]);
  const labels = testData.labels.argMax(-1);
  const logits = model.predict(testxs);
  const preds = logits.argMax(-1);

  // Free the intermediates created here; only the index tensors are returned.
  // NOTE(review): testData.xs / testData.labels are left alone because it is
  // not visible whether nextDataBatch hands out fresh or cached tensors.
  testxs.dispose();
  logits.dispose();

  return [preds, labels];
}
/**
 * Computes per-class accuracy on a test batch and renders it in the
 * tfjs-vis visor under the "Evaluation" tab.
 *
 * Relies on `predict`, the `tfvis` global and a `classNames` array defined
 * elsewhere in the file.
 *
 * @param {object} model - Model passed through to `predict`.
 * @param {object} data - Data source passed through to `predict`.
 * @returns {Promise<void>}
 */
async function displayAccuracyPerClass(model, data) {
  const [preds, labels] = predict(model, data);
  try {
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = {name: 'Accuracy', tab: 'Evaluation'};
    await tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
  } finally {
    // The metric result is plain data, so the tensors can be released even if
    // rendering fails.
    preds.dispose();
    labels.dispose();
  }
}
/**
 * Computes a confusion matrix on a test batch and renders it in the
 * tfjs-vis visor under the "Evaluation" tab.
 *
 * Relies on `predict`, the `tfvis` global and a `classNames` array defined
 * elsewhere in the file.
 *
 * @param {object} model - Model passed through to `predict`.
 * @param {object} data - Data source passed through to `predict`.
 * @returns {Promise<void>}
 */
async function displayConfusionMatrix(model, data) {
  const [preds, labels] = predict(model, data);
  try {
    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
    const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
    // Axis labels go in the data object (`tickLabels`); the third argument of
    // tfvis.render.confusionMatrix is an options object, not a label list.
    await tfvis.render.confusionMatrix(
      container,
      {values: confusionMatrix, tickLabels: classNames},
    );
  } finally {
    preds.dispose();
    labels.dispose();
  }
}
/**
 * Renders both evaluation views (per-class accuracy, then confusion matrix).
 *
 * The two steps run sequentially rather than via Promise.all: each one draws
 * its own batch from `data`, so running them one after the other keeps batch
 * consumption in a deterministic order.
 *
 * @param {object} model - Model to evaluate.
 * @param {object} data - Data source providing test batches.
 * @returns {Promise<void>}
 */
async function evaluateModelFunction(model, data) {
  await displayAccuracyPerClass(model, data);
  await displayConfusionMatrix(model, data);
}
// (Scraped GitHub page footer, kept for provenance:) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment