/*
 * Created March 31, 2020 15:21.
 *
 * Save Rishit-dagli/2b5185a4468032ba3d6312591ebda9c0 to your computer and use it in GitHub Desktop.
 *
 * Model predictions with TF.js
 *
 * This file contains bidirectional Unicode text that may be interpreted or compiled
 * differently than what appears below. To review, open the file in an editor that
 * reveals hidden Unicode characters.
 * Learn more about bidirectional Unicode characters.
 */
/**
 * Plot a trained model's predictions against the original data.
 *
 * Predictions are made for `numPoints` evenly spaced inputs in the normalized
 * range [0, 1]. Both the inputs and the predictions are then un-normalized by
 * inverting the min-max scaling described by `normalizationData`, and the
 * result is rendered with tfjs-vis as a scatterplot alongside the original
 * (horsepower, mpg) points.
 *
 * @param {Object} model - trained model exposing `predict()`; it is called with
 *   a [numPoints, 1] tensor (presumably a tf.LayersModel — confirm with caller).
 * @param {Array<{horsepower: number, mpg: number}>} inputData - original data points.
 * @param {Object} normalizationData - `{inputMax, inputMin, labelMax, labelMin}`,
 *   the values used for min-max scaling; expected to be tensors since they are
 *   combined with `.sub()` / `.mul()` / `.add()`.
 * @param {number} [numPoints=100] - how many evenly spaced points to predict.
 * @throws {RangeError} if `numPoints` is not a positive integer.
 */
function testModel(model, inputData, normalizationData, numPoints = 100) {
  if (!Number.isInteger(numPoints) || numPoints < 1) {
    throw new RangeError(`numPoints must be a positive integer, got ${numPoints}`);
  }
  const {inputMax, inputMin, labelMin, labelMax} = normalizationData;

  // tf.tidy disposes the intermediate tensors created in the callback; only the
  // plain typed arrays produced by dataSync() are returned out of it.
  const [unNormXs, unNormPreds] = tf.tidy(() => {
    const normXs = tf.linspace(0, 1, numPoints);
    const normPreds = model.predict(normXs.reshape([numPoints, 1]));

    // Un-normalize the data: invert the earlier min-max scaling,
    // value = normalized * (max - min) + min.
    const xs = normXs
      .mul(inputMax.sub(inputMin))
      .add(inputMin);
    const preds = normPreds
      .mul(labelMax.sub(labelMin))
      .add(labelMin);

    return [xs.dataSync(), preds.dataSync()];
  });

  const predictedPoints = Array.from(unNormXs, (x, i) => ({x, y: unNormPreds[i]}));

  const originalPoints = inputData.map((d) => ({
    x: d.horsepower,
    y: d.mpg,
  }));

  tfvis.render.scatterplot(
    {name: 'Model Predictions vs Original Data'},
    {values: [originalPoints, predictedPoints], series: ['original', 'predicted']},
    {
      xLabel: 'Horsepower',
      yLabel: 'MPG',
      height: 300,
    },
  );
}
/*
 * Sign up for free to join this conversation on GitHub.
 * Already have an account? Sign in to comment.
 */