-
-
Save dangerdak/3edcdfe6b96aea8b7363b1d571c74147 to your computer and use it in GitHub Desktop.
PropelML: run trained model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const fs = require('fs');
const os = require('os');
const path = require('path');
const pr = require('propel'); // Uncomment if you installed using npm
// const pr = require('./src/api'); // Comment out if you installed using npm
run().catch((err) => {
  // Report failures (e.g. missing checkpoints, load errors) and exit non-zero
  // instead of leaving an unhandled promise rejection.
  console.error(err);
  process.exitCode = 1;
});

/**
 * Loads the trained model's parameters from the latest checkpoint, runs it
 * over the "mnist/test" dataset in batches of 128, and logs the overall
 * accuracy plus the predicted probabilities for one example image.
 */
async function run() {
  // These are the default values used in example.js
  const experimentDir = 'exp001';
  const layers = ['L1', 'L2', 'L3'];
  const params = await loadParams(experimentDir, layers);
  const testData = pr.dataset("mnist/test").batch(128);

  // TODO: would be better to use tensors instead of arrays of tensors
  const allPredictions = [];
  const allLabels = [];
  // Running totals for a size-weighted accuracy. The final batch is usually
  // smaller than 128, so averaging per-batch accuracies would over-weight it.
  let correctCount = pr.tensor(0);
  let exampleCount = 0;

  for (const batchPromise of testData) {
    const { images, labels } = await batchPromise;
    allLabels.push(labels);

    const batchPrediction = images.rescale([0, 255], [-1, 1])
      .linear("L1", params, 200).relu()
      .linear("L2", params, 100).relu()
      .linear("L3", params, 10)
      .softmax();
    // Presumably [batchSize, 10]: for each image, the probability that it
    // belongs to each of the 10 classes.
    allPredictions.push(batchPrediction);

    // assumes labels is a rank-1 tensor with one label per image — TODO confirm
    const batchSize = labels.shape[0];
    correctCount = correctCount.add(
      batchAccuracy(batchPrediction, labels).mul(batchSize),
    );
    exampleCount += batchSize;
  }

  if (exampleCount === 0) {
    console.log('No test data found; nothing to evaluate.');
    return;
  }

  // Accuracy across all examples (not a mean of per-batch means).
  const accuracy = correctCount.div(exampleCount);
  console.log('Accuracy:', accuracy.toString());
  logProbabilityExample(allPredictions, allLabels);
  // logSample(allPredictions, allLabels);
}
// Returns parameters loaded from latest checkpoint | |
/**
 * Loads layer parameters from the most recent checkpoint of an experiment.
 *
 * For every layer, `bias.npy` and `weights.npy` are read from
 * `~/.propel/<expDir>/<latest checkpoint>/<layer>/` and registered on a fresh
 * propel params object under the names `<layer>/bias` and `<layer>/weights`.
 *
 * @param {string} expDir - Experiment directory name, resolved under ~/.propel.
 * @param {string[]} layers - Names of the layers whose parameters to load.
 * @returns {Promise<object>} The populated params object from `pr.params()`.
 */
async function loadParams(expDir, layers) {
  // Relative paths don't work due to a propelml bug, so build an absolute path.
  // os.homedir() also works where $HOME is unset (e.g. on Windows).
  const experimentPath = path.join(os.homedir(), '.propel', expDir);
  // The latest checkpoint cannot change while we load, so resolve it once
  // instead of re-reading the directory for every file.
  const checkpointPath = path.join(experimentPath, latestCheckpoint(experimentPath));
  const params = pr.params();

  // Load each .npy file into a tensor and register it with params.
  for (const layer of layers) {
    for (const paramType of ['bias', 'weights']) {
      const file = path.join(checkpointPath, layer, `${paramType}.npy`);
      const tensor = await pr.load(file);
      params.define(`${layer}/${paramType}`, () => tensor);
    }
  }
  return params;
}
// Returns accuracy for a given batch of predictions and labels | |
/**
 * Computes the fraction of correct predictions in one batch.
 *
 * The predicted class of each row is the index of its largest value along
 * axis 1 (`argmax(1)`). Comparing it element-wise with the labels yields a
 * per-example match tensor. The mean of that tensor along axis 0 is the
 * batch accuracy.
 *
 * @param {object} predictionBatch - Tensor of per-class scores, one row per example.
 * @param {object} labelBatch - Tensor of true class indices for the same examples.
 * @returns {object} Tensor holding the mean of the match tensor.
 */
function batchAccuracy(predictionBatch, labelBatch) {
  const predictedClasses = predictionBatch.argmax(1);
  const matches = predictedClasses.equal(labelBatch);
  return matches.reduceMean(0);
}
// Returns dirname of most recent checkpoint within provided experiment directory | |
/**
 * Returns the directory name of the most recent checkpoint in an experiment.
 *
 * Checkpoint directories are expected to have purely numeric names. The one
 * with the largest numeric value is the latest. Plain files and non-numeric
 * entries (e.g. `.DS_Store`) are ignored. A numeric comparator on raw entries
 * would produce NaN comparisons and an inconsistent sort.
 *
 * @param {string} expDir - Path to the experiment directory.
 * @returns {string} Name of the latest checkpoint directory.
 * @throws {Error} If the directory contains no numeric checkpoint directories.
 */
function latestCheckpoint(expDir) {
  const checkpoints = fs
    .readdirSync(expDir, { withFileTypes: true })
    .filter((entry) => entry.isDirectory() && /^\d+$/.test(entry.name))
    .map((entry) => entry.name);

  if (checkpoints.length === 0) {
    throw new Error(`No checkpoints found in ${expDir}`);
  }

  // Compare numerically: lexicographically "300" would beat "2000".
  return checkpoints.reduce((latest, name) =>
    (Number(name) > Number(latest) ? name : latest));
}
// Logs a sample of predicted/inferred digits and their actual labels | |
/**
 * Prints the predicted classes of the first batch next to its true labels.
 *
 * @param {object[]} predictions - Per-batch prediction tensors; only the first is used.
 * @param {object[]} labels - Per-batch label tensors; only the first is used.
 */
function logSample(predictions, labels) {
  const firstPredictionBatch = predictions[0];
  const predictedDigits = firstPredictionBatch.argmax(1).toString();
  console.log(' Prediction sample:', predictedDigits);

  const firstLabelBatch = labels[0];
  console.log('Corresponding labels:', firstLabelBatch.toString());
}
// Logs the probabilites that the first image is each digit 0-9 and it's label | |
/**
 * Prints the first example from the first batch: its class probabilities and its label.
 *
 * @param {object[]} predictions - Per-batch prediction tensors; element 0 is
 *   gathered at index 0.
 * @param {object[]} labels - Per-batch label tensors; element 0 is gathered at index 0.
 */
function logProbabilityExample(predictions, labels) {
  // NOTE(review): assumes gather([0]) selects the first row, i.e. the first image — confirm.
  const firstImageProbabilities = predictions[0].gather([0]);
  console.log('Probabilities:', firstImageProbabilities.toString());

  const firstImageLabel = labels[0].gather([0]);
  console.log('Label:', firstImageLabel.toString());
}
/**
 * Debug helper: prints every name/tensor pair yielded by iterating `params`.
 *
 * @param {Iterable<[string, object]>} params - Iterable of [name, tensor]
 *   pairs (presumably a propel params object — verify it is iterable).
 */
function logParams(params) {
  for (const entry of params) {
    const [paramName, paramTensor] = entry;
    console.log('name:', paramName);
    console.log('tensor:', paramTensor);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment