
@caisq
Last active April 23, 2019 12:57
Breaking API Changes and New APIs in TensorFlow.js 1.0 Through Examples

1. Model-loading function name changes

Rationale: Indicate the type of model being loaded more explicitly in code.

For models converted from Keras and models saved from TensorFlow.js itself:

// Before: 
await tf.loadModel('http://server/model.json');

// After:
await tf.loadLayersModel('http://server/model.json');

For models converted from TensorFlow SavedModels, Session Bundles, and Frozen Models:

// Before:
await tf.loadFrozenModel('http://server/model.pb', 'http://server/weights-manifest.json');

// After:
await tf.loadGraphModel('http://server/model.json');
// Also note that models converted from TensorFlow now use a single file (model.json)
// to hold both the model topology and the weights manifest.

2. Model-loading function argument consolidation

Rationale: Consolidate the function's growing list of optional arguments into a single options object.

// Before:
const strict = false;
await tf.loadModel('http://server/model.json', strict);

// After:
await tf.loadLayersModel('http://server/model.json', {strict: false});

3. tf.io.browserHTTPRequest() argument consolidation

Rationale: Consolidate the function's growing list of optional arguments into a single options object.

const requestInit = {credentials: 'include'};
const fetchFunc = myCustomFetchFunc;

// Before:
await tf.loadModel(tf.io.browserHTTPRequest(
    'http://server/model.json', requestInit, null, fetchFunc));

// After:
await tf.loadLayersModel(tf.io.browserHTTPRequest('http://server/model.json', {
  requestInit,
  fetchFunc,
  onProgress: fraction => console.log(`Model loading is ${fraction * 100}% done`)
}));
// Also note the model-loading function name change (tf.loadModel() -> tf.loadLayersModel()) and the new `onProgress` field.
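The `onProgress` callback receives a fraction between 0 and 1. As a result, `fraction * 100` can print long decimals such as `33.33333333333333%`. The following is a minimal formatting helper that rounds the value first. It is hypothetical and not part of TensorFlow.js:

```javascript
// Hypothetical helper (not a TensorFlow.js API): formats the fraction passed to
// an onProgress callback as a percentage with one decimal place.
function formatProgress(fraction) {
  // Clamp to [0, 1] defensively before formatting.
  const clamped = Math.min(1, Math.max(0, fraction));
  return `${(clamped * 100).toFixed(1)}%`;
}

// Usage inside the options object shown above:
// onProgress: fraction => console.log(`Model loading is ${formatProgress(fraction)} done`)
console.log(formatProgress(1 / 3)); // 33.3%
console.log(formatProgress(1));     // 100.0%
```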

4. Layers API model: name change

Rationale: Reflect the exact type of model; avoid confusion with models loaded from TensorFlow GraphDef-based formats.

// Before:
console.log(myModel instanceof tf.Model);

// After:
console.log(myModel instanceof tf.LayersModel);

5. Tensor.get() is removed; use Tensor.array() instead.

Rationale: Calling get() repeatedly (for example, in a loop) is a common mistake that harms performance, since each call performs a synchronous read of the tensor's data.

const x = tf.randomNormal([3, 3]);

// Before:
console.log(x.get(2, 1));

// After:
const nestedArray = await x.array();
console.log(nestedArray[2][1]);
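`array()` returns the tensor's values as nested JavaScript arrays, so `nestedArray[2][1]` is the same element that `x.get(2, 1)` used to return. Suppose you already hold the flat values instead (for example, from `data()`). TensorFlow.js stores values in row-major order, so for a `[rows, cols]` tensor the element at `[i][j]` sits at flat index `i * cols + j`.

The following plain-JavaScript sketch uses a hypothetical helper to rebuild a nested array from flat data and a shape. This is roughly the structure `array()` hands back:

```javascript
// Hypothetical helper: rebuild a nested array from flat row-major values and a
// shape, e.g. the 9 flat values of a [3, 3] tensor -> 3 rows of 3 numbers.
function toNested(flat, shape) {
  if (shape.length === 0) return flat[0]; // scalar
  // Number of flat elements covered by one entry along the first axis.
  const stride = shape.slice(1).reduce((a, b) => a * b, 1);
  const result = [];
  for (let k = 0; k < shape[0]; k++) {
    const chunk = Array.from(flat).slice(k * stride, (k + 1) * stride);
    result.push(shape.length === 1 ? chunk[0] : toNested(chunk, shape.slice(1)));
  }
  return result;
}

// Stand-in for the flat data of a [3, 3] tensor.
const flat = Float32Array.from([0, 1, 2, 3, 4, 5, 6, 7, 8]);
const nested = toNested(flat, [3, 3]);
// Row-major: element [2][1] lives at flat index 2 * 3 + 1 = 7.
console.log(nested[2][1], flat[2 * 3 + 1]); // 7 7
```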

6. Tensor.buffer() is now async

Rationale: Encourage async usage to reduce the likelihood of blocking the UI thread.

const x = tf.randomNormal([3, 3]);

// Before:
const myBuffer = x.buffer();

// After:
const myBuffer = await x.buffer();
// or
const myBuffer = x.bufferSync();

7. Put fromPixels() and toPixels() under the tf.browser namespace

Rationale: Make it clear that these functions don't work in Node.js.

// Before:
const imageTensor = tf.fromPixels(canvasElement1);
await tf.toPixels(imageTensor, canvasElement2);

// After:
const imageTensor = tf.browser.fromPixels(canvasElement1);
await tf.browser.toPixels(imageTensor, canvasElement2);
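The pure renames in items 1, 4 and 7 (`tf.loadModel`, `tf.Model`, `tf.fromPixels`, `tf.toPixels`) are mechanical, so a first migration pass can be a plain search-and-replace. The following sketch is a hypothetical, regex-based helper, not an official migration tool. Keep in mind the following limitations:

- It deliberately skips `tf.loadFrozenModel`, whose arguments also changed.
- It will not catch code that imports TensorFlow.js under a name other than `tf`.

```javascript
// Hypothetical regex-based migration pass for the pure renames in items 1, 4 and 7.
// Word boundaries (\b) keep e.g. `tf.Model` from matching inside `tf.ModelFoo`.
const RENAMES = [
  [/\btf\.loadModel\b/g, "tf.loadLayersModel"],
  [/\btf\.Model\b/g, "tf.LayersModel"],
  [/\btf\.fromPixels\b/g, "tf.browser.fromPixels"],
  [/\btf\.toPixels\b/g, "tf.browser.toPixels"],
];

function migrateSource(source) {
  return RENAMES.reduce((code, [pattern, replacement]) => code.replace(pattern, replacement), source);
}

const before = "const m = await tf.loadModel(url); if (m instanceof tf.Model) tf.fromPixels(img);";
console.log(migrateSource(before));
// const m = await tf.loadLayersModel(url); if (m instanceof tf.LayersModel) tf.browser.fromPixels(img);
```

Already-migrated names such as `tf.browser.toPixels` do not match any pattern, so running the pass twice is harmless.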

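8. tf.batchNormalization() is replaced by tf.batchNorm(), with a different argument order

Rationale: Match the argument order of TensorFlow (Python)'s tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon), which is also more logical.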

// Before:
tf.batchNormalization(x, mean, variance, varianceEpsilon, scale, offset);

// After:
tf.batchNorm(x, mean, variance, offset, scale, varianceEpsilon);
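The reordering moves two things:

- `offset` and `scale` trade places.
- `varianceEpsilon` moves to the end.

Porting a call by renaming the function alone would therefore silently feed the wrong values into the normalization.

The following sketch is written in plain JavaScript for 1-D arrays only. It spells out the arithmetic batch normalization performs, `(x - mean) / sqrt(variance + varianceEpsilon) * scale + offset`, using the 1.0 argument order. It also includes a hypothetical adapter that reorders old-style arguments. Neither function is part of TensorFlow.js.

```javascript
// Hypothetical reference implementation for plain 1-D arrays, using the
// TensorFlow.js 1.0 argument order: (x, mean, variance, offset, scale, varianceEpsilon).
function batchNormRef(x, mean, variance, offset, scale, varianceEpsilon) {
  return x.map((xi, i) =>
    ((xi - mean[i]) / Math.sqrt(variance[i] + varianceEpsilon)) * scale[i] + offset[i]);
}

// Hypothetical adapter: accepts arguments in the pre-1.0 order
// (x, mean, variance, varianceEpsilon, scale, offset) and forwards them in the new order.
function batchNormFromOldOrder(x, mean, variance, varianceEpsilon, scale, offset) {
  return batchNormRef(x, mean, variance, offset, scale, varianceEpsilon);
}

const x = [1, 3];
const mean = [2, 2];
const variance = [4, 4];
const scale = [2, 2];
const offset = [10, 10];
console.log(batchNormRef(x, mean, variance, offset, scale, 0));          // [ 9, 11 ]
console.log(batchNormFromOldOrder(x, mean, variance, 0, scale, offset)); // [ 9, 11 ]
// Swapping scale and offset by mistake gives a different result:
console.log(batchNormRef(x, mean, variance, scale, offset, 0));          // [ -3, 7 ]
```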

9. Error message argument of tf.util.assert() is now a function that returns a string

Rationale: Avoid the overhead of building the message string when the assertion passes.

// Before:
tf.util.assert(x.rank >= 3, `x is required to be 3D or higher, but is ${x.rank}D`);

// After:
tf.util.assert(x.rank >= 3, () => `x is required to be 3D or higher, but is ${x.rank}D`);
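The saving comes from deferring the template literal: with a function, the message string is only built when the condition fails. The following plain-JavaScript sketch shows the pattern. It uses a hypothetical `assertLazy`, with a call counter added for demonstration; it is not TensorFlow.js's own code.

```javascript
// Counts how many times a message was actually built (for demonstration only).
let messagesBuilt = 0;

// Hypothetical assert that takes a message function, following the 1.0 calling
// convention of tf.util.assert(expr, msg): the message is only built when expr is false.
function assertLazy(expr, msg) {
  if (!expr) {
    throw new Error(msg());
  }
}

const rank = 4;
// The condition holds, so the arrow function is never called.
assertLazy(rank >= 3, () => { messagesBuilt++; return `x is required to be 3D or higher, but is ${rank}D`; });
console.log(messagesBuilt); // 0

try {
  assertLazy(2 >= 3, () => { messagesBuilt++; return "x is required to be 3D or higher, but is 2D"; });
} catch (e) {
  console.log(e.message, messagesBuilt); // x is required to be 3D or higher, but is 2D 1
}
```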

10. tf.data.Dataset.forEach() is renamed to tf.data.Dataset.forEachAsync()

Rationale: Make it clear that the method is async. Avoid confusion with the built-in Array.prototype.forEach() method.

// Before:
await myDataset.forEach(item => console.log(item));

// After:
await myDataset.forEachAsync(item => console.log(item));
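The distinction matters because the two methods behave differently:

- The built-in `Array.prototype.forEach()` neither waits for async callbacks nor returns a promise.
- A dataset's `forEachAsync()` returns a promise you can `await`.

The following plain-JavaScript sketch imitates that shape with a hypothetical `forEachAsync` over an async iterable. It is an illustration, not the tfjs-data implementation.

```javascript
// Hypothetical stand-in for a dataset: an async generator yielding 1, 2, 3.
async function* numbers() {
  for (let n = 1; n <= 3; n++) yield n;
}

// Hypothetical forEachAsync: pulls each element in turn and resolves once
// the iterable is exhausted, so callers can `await` completion.
async function forEachAsync(asyncIterable, f) {
  for await (const item of asyncIterable) {
    f(item);
  }
}

const seen = [];
const finished = forEachAsync(numbers(), item => seen.push(item));
console.log(seen.length);                      // 0: the callback has not run yet
finished.then(() => console.log(seen));        // [ 1, 2, 3 ]
console.log([1, 2, 3].forEach(n => n) === undefined); // true: nothing to await
```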

11. tf.LayersModel.fitDataset() now expects each item to be an object with two fields: xs and ys, instead of an array of two elements.

Rationale: Make it explicit which values are the features and which are the labels (targets). Avoid surprising behavior during batching.

const model = tf.sequential({layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

const xs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
const ys = [[0], [1], [2]];
const xDataset = tf.data.array(xs).map(x => tf.tensor(x));
const yDataset = tf.data.array(ys).map(y => tf.tensor(y));

// Before:
const dataset = tf.data.zip([xDataset, yDataset]).batch(2);
await model.fitDataset(dataset, {epochs: 4});

// After:
const dataset = tf.data.zip({xs: xDataset, ys: yDataset}).batch(2);
// Note that the features are specified explicitly by the key 'xs' and
// the labels (targets) by the key 'ys'.
await model.fitDataset(dataset, {epochs: 4});
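To see why named fields help, the following plain-JavaScript sketch shows what zipping into `{xs, ys}` and then batching produce conceptually. The helpers are hypothetical, not tfjs-data internals. Each batch keeps features and labels under their own keys, so there is no ambiguity about which half of an element is which. Like the example above, it uses 3 samples with a batch size of 2. This sketch keeps the smaller final batch.

```javascript
// Hypothetical: pair features and labels element-wise as {xs, ys} objects,
// mirroring tf.data.zip({xs: xDataset, ys: yDataset}).
function zipXsYs(xs, ys) {
  return xs.map((x, i) => ({xs: x, ys: ys[i]}));
}

// Hypothetical: group consecutive elements into batches, stacking each key
// separately; a final batch smaller than batchSize is kept.
function batchXsYs(elements, batchSize) {
  const batches = [];
  for (let start = 0; start < elements.length; start += batchSize) {
    const chunk = elements.slice(start, start + batchSize);
    batches.push({xs: chunk.map(e => e.xs), ys: chunk.map(e => e.ys)});
  }
  return batches;
}

const xs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
const ys = [[0], [1], [2]];
const batches = batchXsYs(zipXsYs(xs, ys), 2);
console.log(JSON.stringify(batches));
// [{"xs":[[1,2,3],[4,5,6]],"ys":[[0],[1]]},{"xs":[[7,8,9]],"ys":[[2]]}]
```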

12. tfjs-vis: always put the HTML element first in the argument list

Rationale: Consistency among all tfjs-vis functions.

// Before:
tfvis.render.linechart(dataToPlot, divElement, options);
tfvis.render.scatterplot(dataToPlot, divElement, options);
tfvis.render.barchart(dataToPlot, divElement, options);
tfvis.render.histogram(dataToPlot, divElement, options);
tfvis.render.heatmap(dataToPlot, divElement);

// After:
tfvis.render.linechart(divElement, dataToPlot, options);
tfvis.render.scatterplot(divElement, dataToPlot, options);
tfvis.render.barchart(divElement, dataToPlot, options);
tfvis.render.histogram(divElement, dataToPlot, options);
tfvis.render.heatmap(divElement, dataToPlot);

13. tfjs-vis: confusionMatrix argument changes

Rationale: Succinctness and clarity.

// Before:
tfvis.show.confusionMatrix(divElement, confusionMatrix, classNames);

// After:
tfvis.show.confusionMatrix(divElement, {
  values: confusionMatrix, 
  tickLabels: classNames
});
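The `values` field is a square 2-D array of counts, and `tickLabels` names each class. The following sketch assembles that object from true and predicted class indices. It assumes rows index the true class and columns the predicted class; this is a convention chosen here, not something this item specifies. `buildConfusionMatrixData` is a hypothetical helper.

```javascript
// Hypothetical helper: counts (true class, predicted class) pairs into a square
// matrix and packages it in the {values, tickLabels} shape used above.
// Convention assumed here: rows = true class, columns = predicted class.
function buildConfusionMatrixData(trueIdx, predIdx, classNames) {
  const n = classNames.length;
  const values = Array.from({length: n}, () => new Array(n).fill(0));
  trueIdx.forEach((t, i) => { values[t][predIdx[i]] += 1; });
  return {values, tickLabels: classNames};
}

const data = buildConfusionMatrixData(
    [0, 0, 1, 2, 2], [0, 1, 1, 2, 0], ['apple', 'banana', 'orange']);
console.log(JSON.stringify(data.values)); // [[1,1,0],[0,1,0],[1,0,1]]
// In a browser with tfjs-vis loaded: tfvis.show.confusionMatrix(divElement, data);
```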

14. tfjs-vis: tfvis.render.heatmap() and tfvis.show.confusionMatrix(): rename labels to tickLabels (and xLabels/yLabels to xTickLabels/yTickLabels)

Rationale: Avoid confusion between axis labels and tick labels.

// Before:
tfvis.show.confusionMatrix(divElement, {
  values,
  labels: ['apple', 'banana', 'orange']
});
tfvis.render.heatmap(divElement, {
  values,
  xLabels: ['x1', 'x2', 'x3'],
  yLabels: ['y1', 'y2', 'y3']
});

// After:
tfvis.show.confusionMatrix(divElement, {
  values,
  tickLabels: ['apple', 'banana', 'orange']
});
tfvis.render.heatmap(divElement, {
  values,
  xTickLabels: ['x1', 'x2', 'x3'],
  yTickLabels: ['y1', 'y2', 'y3']
});
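Since the change is a pure key rename, existing option objects can be ported with a small helper. The following sketch is hypothetical (not part of tfjs-vis) and covers only the three keys listed in this item:

```javascript
// Hypothetical helper: renames the pre-1.0 label keys used by
// tfvis.show.confusionMatrix() and tfvis.render.heatmap() to their 1.0 names.
const TICK_LABEL_RENAMES = {labels: 'tickLabels', xLabels: 'xTickLabels', yLabels: 'yTickLabels'};

function renameTickLabelKeys(data) {
  const result = {};
  for (const [key, value] of Object.entries(data)) {
    // Only rename keys defined in the table itself (ignore inherited properties).
    const hasRename = Object.prototype.hasOwnProperty.call(TICK_LABEL_RENAMES, key);
    result[hasRename ? TICK_LABEL_RENAMES[key] : key] = value;
  }
  return result;
}

const heatmapData = renameTickLabelKeys(
    {values: [[1, 2], [3, 4]], xLabels: ['x1', 'x2'], yLabels: ['y1', 'y2']});
console.log(Object.keys(heatmapData)); // [ 'values', 'xTickLabels', 'yTickLabels' ]
```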

15. tf.data.generator() has been renamed to tf.data.func()

Rationale: In v1.0.0, tf.data.generator() expects an actual JavaScript generator function (function*). In 0.x, it took a plain function, so the name was a misnomer.

let i = 0;

// Before:
const gen = tf.data.generator(() => ({value: ++i, done: i === 10}));

// After:
const gen = tf.data.func(() => ({value: ++i, done: i === 10}));
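For comparison, 1.0's `tf.data.generator()` expects a JavaScript generator function. The following plain-JavaScript sketch needs no TensorFlow.js and uses hypothetical helpers. It shows that a generator function yields the same elements as the plain function above, namely `1` through `9`. The call that sets `i` to 10 also reports `done: true`, so that value is not part of the sequence.

```javascript
// The 0.x-style plain function from the example above: each call returns an
// iterator-result object {value, done}.
let i = 0;
const nextFn = () => ({value: ++i, done: i === 10});

// Hypothetical: collect values from such a function until it reports done
// (the value of the final {done: true} result is ignored, as with for...of).
function collectFromNextFn(next) {
  const out = [];
  for (let r = next(); !r.done; r = next()) out.push(r.value);
  return out;
}

// Hypothetical equivalent generator function, the kind of argument that
// tf.data.generator() expects in 1.0.
function* countToNine() {
  for (let n = 1; n < 10; n++) yield n;
}

console.log(JSON.stringify(collectFromNextFn(nextFn)) === JSON.stringify([...countToNine()])); // true
```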
@hpssjellis

Note: tf.matrixTimesVector() quietly disappeared from the TensorFlow.js API at version 12.1, sensibly deprecated since tf.dot() now handles it.
It is strange that tf.matrixTimesVector() works in version 12.0 but has been completely removed from the API for ANY version.

@RadEdje

RadEdje commented Apr 10, 2019

Hi, I'm currently on TensorFlow 2.0 alpha (Python) and TensorFlow.js 1.0.0.
I just used the tensorflowjs_converter.
I have 3 weight shards: shard1of3, shard2of3 and shard3of3.bin.
They've all loaded in the browser. I checked the element inspector: they all return 200 "OK" and are recognized as octet-stream.
I'm using tf.loadLayersModel but nothing happens. The model JSON and the .bin files have loaded, even the settings file. Is there any other file I should be waiting for to guarantee that the async/await model-loading call in tfjs does not fail? Thanks.
