1. Model-loading function name changes
Rationale: Indicate the type of model being loaded more explicitly in code.
For models converted from Keras and models saved from TensorFlow.js itself:
// Before:
await tf.loadModel('http://server/model.json');
// After:
await tf.loadLayersModel('http://server/model.json');
For models converted from TensorFlow SavedModels, Session Bundles, and Frozen Models:
// Before:
await tf.loadFrozenModel('http://server/model.pb', 'http://server/weights-manifest.json');
// After:
await tf.loadGraphModel('http://server/model.json');
// Also note that models converted from TensorFlow now use a single file
// to hold both the model topology and the weights manifest.
2. Model-loading function argument consolidation
Rationale: Consolidate the growing list of optional arguments of the function.
// Before:
const strict = false;
await tf.loadModel('http://server/model.json', strict);
// After:
await tf.loadLayersModel('http://server/model.json', {strict: false});
3. tf.io.browserHTTPRequest() argument consolidation
Rationale: Consolidate the growing list of optional arguments of the function.
const requestInit = {credentials: 'include'};
const fetchFunc = myCustomFetchFunc;
// Before:
await tf.loadModel(tf.io.browserHTTPRequest(
'http://server/model.json', requestInit, null, fetchFunc));
// After:
await tf.loadLayersModel(tf.io.browserHTTPRequest('http://server/model.json', {
  requestInit,
  fetchFunc,
  onProgress: fraction => console.log(`Model loading is ${fraction * 100}% done`)
}));
// Also note the function name change (tf.loadModel() -> tf.loadLayersModel()) and the new `onProgress` field.
4. Layers API model: name change
Rationale: Reflect the exact type of model and avoid confusion with GraphDef-based models converted from TensorFlow.
// Before:
console.log(myModel instanceof tf.Model);
// After:
console.log(myModel instanceof tf.LayersModel);
5. Tensor.get() is removed; use Tensor.array() instead.
Rationale: Repeated get() calls are a common mistake that harms performance.
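Both Tensor.array() and Tensor.data() download a tensor's values in a single call, so the fix is to read once and index many times. If you already hold the flat values from `await x.data()`, the element at a multi-dimensional index can be located with a row-major offset. This is a plain-JavaScript sketch: `rowMajorOffset` is a hypothetical helper, not a TF.js API, and the `Float32Array` stands in for the result of `x.data()`.

```javascript
// Hypothetical helper (not a TF.js API): maps multi-dimensional indices to an
// offset in a flat, row-major array such as the one returned by Tensor.data().
function rowMajorOffset(shape, indices) {
  if (indices.length !== shape.length) {
    throw new Error(`Expected ${shape.length} indices, got ${indices.length}`);
  }
  let offset = 0;
  for (let d = 0; d < shape.length; d++) {
    if (indices[d] < 0 || indices[d] >= shape[d]) {
      throw new RangeError(`Index ${indices[d]} is out of bounds for dimension ${d}`);
    }
    // Horner-style accumulation: ((i0 * s1 + i1) * s2 + i2) ...
    offset = offset * shape[d] + indices[d];
  }
  return offset;
}

// Stand-in for `await x.data()` on a 3x3 tensor whose values are 0..8.
const flat = Float32Array.from({length: 9}, (_, i) => i);

// Equivalent of the old x.get(2, 1), but with the data read only once.
console.log(flat[rowMajorOffset([3, 3], [2, 1])]); // 7
```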
const x = tf.randomNormal([3, 3]);
// Before:
console.log(x.get(2, 1));
// After:
const nestedArray = await x.array();
console.log(nestedArray[2][1]);
6. Tensor.buffer() is now async
Rationale: Encourage async usage to reduce the likelihood of blocking UI.
const x = tf.randomNormal([3, 3]);
// Before:
const myBuffer = x.buffer();
// After:
const myBuffer = await x.buffer();
// or
const myBuffer = x.bufferSync();
7. Put fromPixels() and toPixels() under the tf.browser namespace
Rationale: Make it clear that these functions don't work in Node.js.
// Before:
const imageTensor = tf.fromPixels(canvasElement1);
await tf.toPixels(imageTensor, canvasElement2);
// After:
const imageTensor = tf.browser.fromPixels(canvasElement1);
await tf.browser.toPixels(imageTensor, canvasElement2);
8. tf.batchNormalization() is replaced by tf.batchNorm(); adjust argument order
Rationale: Consistency with TensorFlow (Python) and a more logical argument order.
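To make the reordering concrete, here is a plain-JavaScript reference of the batch-normalization formula, y = scale * (x - mean) / sqrt(variance + varianceEpsilon) + offset, with parameters in the new tf.batchNorm() order, plus a hypothetical helper that reorders arguments written for the old tf.batchNormalization() call. Both helpers are illustrative assumptions rather than TF.js code, and they handle only scalar statistics over a flat array.

```javascript
// Reference batch normalization over a flat array with scalar statistics.
// Parameter order mirrors tf.batchNorm(x, mean, variance, offset, scale, varianceEpsilon).
// Illustrative only; not the TF.js implementation.
function batchNormRef(x, mean, variance, offset, scale, varianceEpsilon) {
  // Missing offset/scale default to the identity transform.
  const shift = offset ?? 0;
  const gain = scale ?? 1;
  const invStd = 1 / Math.sqrt(variance + varianceEpsilon);
  return x.map(v => (v - mean) * invStd * gain + shift);
}

// Hypothetical migration helper: accepts arguments in the old
// tf.batchNormalization(x, mean, variance, varianceEpsilon, scale, offset) order
// and returns them in the new tf.batchNorm(...) order.
function toBatchNormArgs(x, mean, variance, varianceEpsilon, scale, offset) {
  return [x, mean, variance, offset, scale, varianceEpsilon];
}

// Old-order arguments: x, mean, variance, varianceEpsilon, scale, offset.
const legacyArgs = [[1, 2, 3], 2, 1, 0, 2, 10];
console.log(batchNormRef(...toBatchNormArgs(...legacyArgs))); // [ 8, 10, 12 ]
```

With TF.js 1.0 loaded, the same reordering helper could feed the real op, e.g. `tf.batchNorm(...toBatchNormArgs(x, mean, variance, varianceEpsilon, scale, offset))`.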
// Before:
tf.batchNormalization(x, mean, variance, varianceEpsilon, scale, offset);
// After:
tf.batchNorm(x, mean, variance, offset, scale, varianceEpsilon);
9. Error message argument of tf.util.assert() is now a function that returns a string
Rationale: Avoid unnecessary overhead of string creation.
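A minimal sketch of why the message is now a function: the string, including its template-literal interpolation, is only built when the assertion fails, so passing checks skip that work entirely. `assertLazy` is a hypothetical stand-in written to mirror the new contract, not the TF.js source.

```javascript
// Hypothetical stand-in mirroring the new tf.util.assert(expr, msg) contract,
// where msg is a function that returns the error message.
function assertLazy(expr, msg) {
  if (!expr) {
    // The message string is built only on failure.
    throw new Error(msg());
  }
}

let messagesBuilt = 0;
function rankMessage(rank) {
  return () => {
    messagesBuilt++;
    return `x is required to be 3D or higher, but is ${rank}D`;
  };
}

assertLazy(4 >= 3, rankMessage(4)); // passes: the message function is never called
console.log(messagesBuilt); // 0

try {
  assertLazy(2 >= 3, rankMessage(2));
} catch (e) {
  console.log(e.message); // x is required to be 3D or higher, but is 2D
}
console.log(messagesBuilt); // 1
```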
// Before:
tf.util.assert(x.rank >= 3, `x is required to be 3D or higher, but is ${x.rank}D`);
// After:
tf.util.assert(x.rank >= 3, () => `x is required to be 3D or higher, but is ${x.rank}D`);
10. tf.data.Dataset.forEach() is renamed to tf.data.Dataset.forEachAsync()
Rationale: Make it clear that the method is async, and avoid confusion with the built-in Array.prototype.forEach().
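The sketch below illustrates why the Async suffix matters: when elements are produced asynchronously, iteration has barely started when forEachAsync() returns, so the returned promise must be awaited before relying on the results. `ToyDataset` is a hypothetical plain-JavaScript stand-in, not how tf.data is implemented.

```javascript
// Hypothetical stand-in for a dataset whose elements arrive asynchronously
// (for example, read from disk or over the network).
class ToyDataset {
  constructor(items) {
    this.items = items;
  }

  async *[Symbol.asyncIterator]() {
    for (const item of this.items) {
      // Yield control to the event loop before producing each element.
      await new Promise(resolve => setTimeout(resolve, 0));
      yield item;
    }
  }

  // Visits every element; the returned promise resolves after the last one.
  async forEachAsync(f) {
    for await (const item of this) {
      f(item);
    }
  }
}

const seen = [];
const done = new ToyDataset([1, 2, 3]).forEachAsync(item => seen.push(item));
console.log(seen.length); // 0: nothing has been visited yet
done.then(() => console.log(seen)); // [ 1, 2, 3 ] once the promise resolves
```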
// Before:
await myDataset.forEach(item => console.log(item));
// After:
await myDataset.forEachAsync(item => console.log(item));
11. tf.LayersModel.fitDataset() now expects each item to be an object with two fields, xs and ys, instead of an array of two elements.
Rationale: Make it explicit which values are the features and which are the labels (targets), and avoid surprising behavior during batching.
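If an existing pipeline still builds `[features, labels]` pairs, one way to migrate is to map each pair to the `{xs, ys}` shape before batching, for example `tf.data.zip([xDataset, yDataset]).map(pairToXsYs).batch(2)`. The adapter below is a hypothetical helper, demonstrated on plain arrays that stand in for the dataset elements.

```javascript
// Hypothetical adapter: turns an old-style [features, labels] element into the
// {xs, ys} object that fitDataset() expects as of v1.0.
function pairToXsYs([xs, ys]) {
  return {xs, ys};
}

// Plain-array stand-ins for the elements of xDataset and yDataset.
const xs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
const ys = [[0], [1], [2]];

// Old style: what zipping the two datasets as an array would yield per element.
const pairs = xs.map((x, i) => [x, ys[i]]);

// New style: explicit feature and label keys.
const examples = pairs.map(pairToXsYs);
// examples[1] is {xs: [4, 5, 6], ys: [1]}
```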
const model = tf.sequential({layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
const xs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
const ys = [[0], [1], [2]];
const xDataset = tf.data.array(xs).map(x => tf.tensor(x));
const yDataset = tf.data.array(ys).map(y => tf.tensor(y));
// Before:
const dataset = tf.data.zip([xDataset, yDataset]).batch(2);
await model.fitDataset(dataset, {epochs: 4});
// After:
const dataset = tf.data.zip({xs: xDataset, ys: yDataset}).batch(2);
// Note that the features are specified explicitly by the key 'xs' and
// the labels (targets) by the key 'ys'.
await model.fitDataset(dataset, {epochs: 4});
12. tfjs-vis: always put the HTML element first in the argument list
Rationale: Consistency among all tfjs-vis functions.
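In a large codebase, a temporary shim can keep old call sites working while they are updated. The sketch below is a hypothetical wrapper (not part of tfjs-vis) that accepts the old `(data, container, options)` order and forwards the arguments in the new container-first order; it is demonstrated with a fake render function so it runs without tfjs-vis.

```javascript
// Hypothetical migration shim: wraps a container-first render function so it
// can be called with the pre-1.0 (data, container, ...rest) argument order.
function withLegacyArgOrder(renderFn) {
  return (data, container, ...rest) => renderFn(container, data, ...rest);
}

// Fake render function standing in for e.g. tfvis.render.linechart.
const calls = [];
function fakeLinechart(container, data, options) {
  calls.push({container, data, options});
}

const legacyLinechart = withLegacyArgOrder(fakeLinechart);
legacyLinechart({values: [{x: 0, y: 1}]}, 'div#chart', {width: 300});
// calls[0].container is 'div#chart' and calls[0].data is the data object
```

Wrapping the real functions this way assumes they do not depend on `this`; updating the call sites directly remains the cleaner long-term fix.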
// Before:
tfvis.render.linechart(dataToPlot, divElement, options);
tfvis.render.scatterplot(dataToPlot, divElement, options);
tfvis.render.barchart(dataToPlot, divElement, options);
tfvis.render.histogram(dataToPlot, divElement, options);
tfvis.render.heatmap(dataToPlot, divElement);
// After:
tfvis.render.linechart(divElement, dataToPlot, options);
tfvis.render.scatterplot(divElement, dataToPlot, options);
tfvis.render.barchart(divElement, dataToPlot, options);
tfvis.render.histogram(divElement, dataToPlot, options);
tfvis.render.heatmap(divElement, dataToPlot);
13. tfjs-vis: confusionMatrix argument changes
Rationale: Succinctness and clarity.
// Before:
tfvis.show.confusionMatrix(divElement, confusionMatrix, classNames);
// After:
tfvis.show.confusionMatrix(divElement, {
  values: confusionMatrix,
  tickLabels: classNames
});
14. tfjs-vis: tfvis.render.heatmap() and tfvis.show.confusionMatrix(): rename labels to tickLabels
Rationale: Avoid confusion between axis labels and tick labels.
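When chart data objects are built in one place and passed around, these renames can be applied mechanically. The helper below is a hypothetical sketch, not a tfjs-vis API: it returns a shallow copy of a data object with `labels`, `xLabels`, and `yLabels` renamed to `tickLabels`, `xTickLabels`, and `yTickLabels`, leaving other fields such as `values` untouched.

```javascript
// Old field name -> v1.0 field name for confusion-matrix and heatmap data.
const TICK_LABEL_RENAMES = new Map([
  ['labels', 'tickLabels'],
  ['xLabels', 'xTickLabels'],
  ['yLabels', 'yTickLabels'],
]);

// Hypothetical helper: returns a shallow copy of `data` with renamed label fields.
function renameLabelFields(data) {
  const renamed = {};
  for (const [key, value] of Object.entries(data)) {
    renamed[TICK_LABEL_RENAMES.get(key) ?? key] = value;
  }
  return renamed;
}

const heatmapData = renameLabelFields({
  values: [[1, 2], [3, 4]],
  xLabels: ['x1', 'x2'],
  yLabels: ['y1', 'y2'],
});
// heatmapData has the keys values, xTickLabels and yTickLabels
```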
// Before:
tfvis.show.confusionMatrix(divElement, {
  values,
  labels: ['apple', 'banana', 'orange']
});
tfvis.render.heatmap(divElement, {
  values,
  xLabels: ['x1', 'x2', 'x3'],
  yLabels: ['y1', 'y2', 'y3']
});
// After:
tfvis.show.confusionMatrix(divElement, {
  values,
  tickLabels: ['apple', 'banana', 'orange']
});
tfvis.render.heatmap(divElement, {
  values,
  xTickLabels: ['x1', 'x2', 'x3'],
  yTickLabels: ['y1', 'y2', 'y3']
});
15. tf.data.generator() has been renamed to tf.data.func()
Rationale: In v1.0.0, tf.data.generator() expects an actual JavaScript generator function. In 0.x, it took a plain function returning iterator results, so its name was a misnomer.
let i = 0;
// Before:
const gen = tf.data.generator(() => ({value: ++i, done: i === 10}));
// After:
const gen = tf.data.func(() => ({value: ++i, done: i === 10}));
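The two calling conventions can describe the same sequence. The plain-JavaScript sketch below contrasts the old-style function from the example above with an equivalent generator function of the kind tf.data.generator() expects in v1.0, and checks that both produce 1 through 9. The helper names are illustrative, not TF.js APIs.

```javascript
// Old style: a plain function returning one iterator result per call,
// the same shape as the function passed to tf.data.func() above.
function makeCounterFunc() {
  let i = 0;
  return () => ({value: ++i, done: i === 10});
}

// New style: a genuine generator function.
function* counter() {
  for (let i = 1; i < 10; i++) {
    yield i;
  }
}

// Collects values from a "next"-style function until it reports done.
function drain(next) {
  const values = [];
  for (let result = next(); !result.done; result = next()) {
    values.push(result.value);
  }
  return values;
}

const fromFunc = drain(makeCounterFunc());
const fromGenerator = [...counter()];
// Both arrays hold 1 through 9: the result with done === true is not collected.
```

With TF.js 1.0, the generator form would be passed as `tf.data.generator(counter)`.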
Note: tf.matrixTimesVector() quietly disappeared from the TensorFlow.js API at version 12.1, sensibly deprecated because tf.dot() now handles matrix-vector products.
Strangely, tf.matrixTimesVector() still works in version 12.0, yet it has been completely removed from the API for every version.