Skip to content

Instantly share code, notes, and snippets.

@dangerdak
Last active April 16, 2018 11:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dangerdak/3edcdfe6b96aea8b7363b1d571c74147 to your computer and use it in GitHub Desktop.
Save dangerdak/3edcdfe6b96aea8b7363b1d571c74147 to your computer and use it in GitHub Desktop.
PropelML: run trained model
const fs = require('fs');
const path = require('path');
const pr = require('propel'); // Uncomment if you installed using npm
// const pr = require('./src/api'); // Comment out if you installed using npm
// Script entry point. `run` is async, so its promise must not be left
// floating: report any failure and exit with a non-zero status instead of
// relying on the runtime's unhandled-rejection behaviour.
run().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
/**
 * Evaluates the trained network on the "mnist/test" dataset: runs every
 * batch through the three linear layers, prints the mean of the per-batch
 * accuracies, then logs the class probabilities of one example image.
 */
async function run() {
  // Experiment directory and layer names match the defaults in example.js.
  const layerNames = ['L1', 'L2', 'L3'];
  const weights = await loadParams('exp001', layerNames);
  const batches = pr.dataset("mnist/test").batch(128);

  // One entry per batch (single concatenated tensors would be nicer than
  // arrays of tensors).
  const predictionBatches = [];
  const labelBatches = [];
  const batchAccuracies = [];

  for (const pendingBatch of batches) {
    const { images: pixels, labels: digits } = await pendingBatch;
    labelBatches.push(digits);

    // Forward pass: rescale pixel values from [0, 255] to [-1, 1], two
    // ReLU hidden layers (200 and 100 units), then a 10-way softmax.
    const scaled = pixels.rescale([0, 255], [-1, 1]);
    const hidden1 = scaled.linear("L1", weights, 200).relu();
    const hidden2 = hidden1.linear("L2", weights, 100).relu();
    // Per the original notes this is [batch, 10]: for each image, the
    // probability that it belongs to each class.
    const probabilities = hidden2.linear("L3", weights, 10).softmax();

    predictionBatches.push(probabilities);
    batchAccuracies.push(batchAccuracy(probabilities, digits));
  }

  // Mean accuracy across batches. Every batch counts equally here,
  // whatever its size.
  // NOTE(review): this divides a tensor by a number with the JS `/`
  // operator, so it relies on the tensor coercing to a number - confirm.
  let summed = pr.tensor(0);
  for (const oneAccuracy of batchAccuracies) {
    summed = summed.add(oneAccuracy);
  }
  const accuracy = summed / batchAccuracies.length;
  console.log('Accuracy:', accuracy);

  // Alternative output: logSample(predictionBatches, labelBatches)
  logProbabilityExample(predictionBatches, labelBatches);
}
/**
 * Loads the bias and weights of every listed layer from the most recent
 * checkpoint of an experiment.
 *
 * @param {string} expDir - Experiment directory name; it is resolved under
 *   `$HOME/.propel`.
 * @param {string[]} layers - Layer names; each is read from a sub-directory
 *   of the checkpoint with the same name.
 * @returns {Promise<object>} The `pr.params()` container with
 *   `<layer>/bias` and `<layer>/weights` defined for every layer.
 */
async function loadParams(expDir, layers) {
  // An absolute path is built because, per the original author's note,
  // relative paths do not work due to a Propel bug.
  const experimentPath = path.join(process.env.HOME, '.propel', expDir);

  // Resolve the checkpoint once: the directory is scanned a single time and
  // every parameter file is guaranteed to come from the same checkpoint.
  const checkpointPath = path.join(experimentPath, latestCheckpoint(experimentPath));

  const params = pr.params();
  // Load the contents of each .npy file into a tensor, then register it.
  for (const layer of layers) {
    for (const paramType of ['bias', 'weights']) {
      const file = path.join(checkpointPath, layer, `${paramType}.npy`);
      const tensor = await pr.load(file);
      // `define` takes an initializer callback; hand back the loaded tensor.
      params.define(`${layer}/${paramType}`, () => tensor);
    }
  }
  return params;
}
/**
 * Computes the accuracy of one batch of predictions against its labels.
 *
 * @param predictionBatch - Tensor of per-class scores, one row per example.
 * @param labelBatch - Tensor of the expected class indices.
 * @returns The result of `reduceMean(0)` over the match tensor, i.e. the
 *   fraction of the batch that was classified correctly.
 */
function batchAccuracy(predictionBatch, labelBatch) {
  // The index of the largest value along axis 1 is taken as the model's
  // predicted class for each example.
  const predictedClasses = predictionBatch.argmax(1);
  // Comparing with the labels gives 1 where the prediction matches and 0
  // where it does not.
  const matches = predictedClasses.equal(labelBatch);
  // The mean of those 1s and 0s is the accuracy for this batch.
  return matches.reduceMean(0);
}
/**
 * Returns the name of the most recent checkpoint inside an experiment
 * directory, i.e. the entry whose name is the largest number.
 *
 * Entries whose names are not numeric (stray files such as `.DS_Store`) are
 * ignored so they cannot disturb the ordering.
 *
 * @param {string} expDir - Path of the experiment directory to scan.
 * @returns {string} Name of the entry with the highest numeric value.
 * @throws {Error} If the directory contains no numerically named entries.
 */
function latestCheckpoint(expDir) {
  let latestName;
  let latestStep = -Infinity;

  for (const entry of fs.readdirSync(expDir)) {
    const step = Number(entry);
    // Number('  ') is 0, so blank names are rejected explicitly.
    if (entry.trim() === '' || Number.isNaN(step)) {
      continue;
    }
    if (latestName === undefined || step > latestStep) {
      latestName = entry;
      latestStep = step;
    }
  }

  if (latestName === undefined) {
    throw new Error(`No checkpoint directories found in ${expDir}`);
  }
  return latestName;
}
/**
 * Logs the predicted digits for the first batch, followed by the labels of
 * that same batch.
 *
 * @param predictions - Array of per-batch prediction tensors.
 * @param labels - Array of per-batch label tensors.
 */
function logSample(predictions, labels) {
  // argmax along axis 1 turns each row of scores into a predicted digit.
  const predictedDigits = predictions[0].argmax(1);
  console.log(' Prediction sample:', predictedDigits.toString());

  const expectedDigits = labels[0];
  console.log('Corresponding labels:', expectedDigits.toString());
}
/**
 * Logs the class probabilities of the first image in the first batch (one
 * value per digit 0-9, per the original notes) together with its label.
 *
 * @param predictions - Array of per-batch prediction tensors.
 * @param labels - Array of per-batch label tensors.
 */
function logProbabilityExample(predictions, labels) {
  // gather([0]) selects index 0 of the batch, i.e. the first example.
  const firstProbabilities = predictions[0].gather([0]);
  console.log('Probabilities:', firstProbabilities.toString());

  const firstLabel = labels[0].gather([0]);
  console.log('Label:', firstLabel.toString());
}
/**
 * Logs every parameter in a params collection: its name, then its tensor.
 *
 * @param params - Iterable of `[name, tensor]` pairs.
 */
function logParams(params) {
  for (const entry of params) {
    const [label, value] = entry;
    console.log('name:', label);
    console.log('tensor:', value);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment