Skip to content

Instantly share code, notes, and snippets.

@mizchi

mizchi/tf.es.js

Created Jul 27, 2020
Embed
What would you like to do?
This file has been truncated, but you can view the full file.
/* prebuilt es */
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * Flags are evaluated lazily: the first `get`/`getAsync` of a flag runs its
 * registered evaluation function and caches the result in `this.flags`.
 * Values parsed from the `tfjsflags` URL query parameter are kept in
 * `this.urlFlags` and applied when the matching flag is registered.
 */
/** @doc {heading: 'Environment'} */
class Environment {
  /**
   * @param global The global object whose `location.search` is scanned for
   *     URL flag overrides; scanning is skipped when it (or its
   *     `location.search`) is undefined.
   */
  // tslint:disable-next-line: no-any
  constructor(global) {
    this.global = global;
    this.flags = {};
    this.flagRegistry = {};
    this.urlFlags = {};
    this.populateURLFlags();
  }
  /**
   * Sets the active platform. Replacing an already-set platform is allowed
   * but logs a warning.
   */
  setPlatform(platformName, platform) {
    if (this.platform != null) {
      // Report the new platform by name; interpolating the platform object
      // itself would typically print "[object Object]".
      console.warn(`Platform ${this.platformName} has already been set. ` +
        `Overwriting the platform with ${platformName}.`);
    }
    this.platformName = platformName;
    this.platform = platform;
  }
  /**
   * Registers a flag.
   *
   * @param flagName Name of the flag.
   * @param evaluationFn Computes the flag's value on first read. It may return
   *     a Promise, in which case the flag can only be read with `getAsync`.
   * @param setHook Optional callback invoked with the new value on `set`.
   */
  registerFlag(flagName, evaluationFn, setHook) {
    this.flagRegistry[flagName] = { evaluationFn, setHook };
    // Override the flag value from the URL. This has to happen here because the
    // environment is initialized before flags get registered.
    if (this.urlFlags[flagName] != null) {
      const flagValue = this.urlFlags[flagName];
      console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
      this.set(flagName, flagValue);
    }
  }
  /** Returns a flag's value, awaiting its evaluation function on first use. */
  async getAsync(flagName) {
    if (flagName in this.flags) {
      return this.flags[flagName];
    }
    this.flags[flagName] = await this.evaluateFlag(flagName);
    return this.flags[flagName];
  }
  /**
   * Returns a flag's value, evaluating and caching it on first use.
   *
   * @throws Error if the flag's evaluation function returns a Promise.
   */
  get(flagName) {
    if (flagName in this.flags) {
      return this.flags[flagName];
    }
    const flagValue = this.evaluateFlag(flagName);
    if (flagValue instanceof Promise) {
      throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
        `Please use getAsync() instead.`);
    }
    this.flags[flagName] = flagValue;
    return this.flags[flagName];
  }
  /** Alias of `get`; no numeric conversion is performed. */
  getNumber(flagName) {
    return this.get(flagName);
  }
  /** Alias of `get`; no boolean conversion is performed. */
  getBool(flagName) {
    return this.get(flagName);
  }
  /** Returns the object holding the currently cached flag values. */
  getFlags() {
    return this.flags;
  }
  // For backwards compatibility.
  get features() {
    return this.flags;
  }
  /**
   * Sets a flag's value and invokes its `setHook`, if one was registered.
   *
   * @throws Error if the flag has not been registered.
   */
  set(flagName, value) {
    if (this.flagRegistry[flagName] == null) {
      throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
    }
    this.flags[flagName] = value;
    if (this.flagRegistry[flagName].setHook != null) {
      this.flagRegistry[flagName].setHook(value);
    }
  }
  /**
   * Runs the registered evaluation function for a flag. The result is not
   * cached here.
   */
  evaluateFlag(flagName) {
    if (this.flagRegistry[flagName] == null) {
      throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
    }
    return this.flagRegistry[flagName].evaluationFn();
  }
  /** Replaces all cached flag values with a shallow copy of `flags`. */
  setFlags(flags) {
    this.flags = Object.assign({}, flags);
  }
  /** Clears cached flag values and re-reads URL overrides; registrations are kept. */
  reset() {
    this.flags = {};
    this.urlFlags = {};
    this.populateURLFlags();
  }
  /** Parses `tfjsflags=KEY:VALUE,...` from `global.location.search` into `urlFlags`. */
  populateURLFlags() {
    if (typeof this.global === 'undefined' ||
      typeof this.global.location === 'undefined' ||
      typeof this.global.location.search === 'undefined') {
      return;
    }
    const urlParams = getQueryParams(this.global.location.search);
    if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
      const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
      keyValues.forEach(keyValue => {
        const [key, value] = keyValue.split(':');
        this.urlFlags[key] = parseValue(key, value);
      });
    }
  }
}
/**
 * Parses a URL query string (such as `location.search`) into a plain object
 * that maps each URI-decoded parameter name to its URI-decoded value.
 * A parameter without `=` maps to the empty string, and a repeated name keeps
 * its last value.
 */
function getQueryParams(queryString) {
  const params = {};
  const paramPattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
  let match = paramPattern.exec(queryString);
  while (match !== null) {
    decodeParam(params, match[1], match[2]);
    match = paramPattern.exec(queryString);
  }
  return params;
}
/** Stores the URI-decoded `name` -> `value` pair on `params` (a missing value becomes ''). */
function decodeParam(params, name, value) {
  const decodedName = decodeURIComponent(name);
  params[decodedName] = decodeURIComponent(value || '');
}
/**
 * Parses a flag value taken from the `tfjsflags` URL parameter.
 *
 * `'true'`/`'false'` (case-insensitive) become booleans. Numeric strings in
 * canonical form (ones that round-trip through `+value`, e.g. `'2'`, `'0.5'`)
 * become numbers. Anything else is rejected, including a missing value (a
 * `FLAG` entry with no `:value`).
 *
 * @param flagName Name of the flag; used only in the error message.
 * @param value Raw string value. It is undefined when the URL entry had no `:`.
 * @throws Error if the value cannot be parsed.
 */
function parseValue(flagName, value) {
  // A `FLAG` entry without `:value` yields undefined here. Report it through
  // the regular parse error instead of crashing with a TypeError from
  // `toLowerCase`.
  if (value == null) {
    throw new Error(`Could not parse value flag value ${value} for flag ${flagName}.`);
  }
  const normalized = value.toLowerCase();
  if (normalized === 'true' || normalized === 'false') {
    return normalized === 'true';
  }
  else if (`${+normalized}` === normalized) {
    return +normalized;
  }
  throw new Error(`Could not parse value flag value ${normalized} for flag ${flagName}.`);
}
// The environment singleton returned by env(); null until
// setEnvironmentGlobal() installs one.
let ENV = null;
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform. The result is null if `setEnvironmentGlobal` has not been
 * called yet.
 */
/** @doc {heading: 'Environment'} */
function env() {
  return ENV;
}
/** Installs `environment` as the object returned by `env()`. */
function setEnvironmentGlobal(environment) {
  ENV = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
let globalNameSpace;
/**
 * Returns the object used to hold library-wide globals, caching it after the
 * first lookup.
 *
 * The lookup order (`window`, `global`, `process`, `self`) is unchanged, so
 * existing environments keep resolving to the same object. `globalThis` is
 * tried last, so runtimes that expose none of those names still work instead
 * of throwing.
 *
 * @throws Error if no global object can be found.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
  if (globalNameSpace == null) {
    // tslint:disable-next-line:no-any
    let ns;
    if (typeof window !== 'undefined') {
      ns = window;
    } else if (typeof global !== 'undefined') {
      ns = global;
    } else if (typeof process !== 'undefined') {
      ns = process;
    } else if (typeof self !== 'undefined') {
      ns = self;
    } else if (typeof globalThis !== 'undefined') {
      ns = globalThis;
    } else {
      throw new Error('Could not find a global object');
    }
    globalNameSpace = ns;
  }
  return globalNameSpace;
}
/** Returns the `Map` stored as `_tfGlobals` on the global namespace, creating it on first use. */
// tslint:disable-next-line:no-any
function getGlobalMap() {
  const ns = getGlobalNamespace();
  if (ns._tfGlobals == null) {
    ns._tfGlobals = new Map();
  }
  return ns._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function that creates the object; it is called only the first
 *     time `key` is fetched.
 */
function getGlobal(key, init) {
  const globalMap = getGlobalMap();
  if (!globalMap.has(key)) {
    globalMap.set(key, init());
  }
  return globalMap.get(key);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Registries held in the global-namespace map via getGlobal().
// kernelRegistry: makeKey(kernelName, backendName) -> kernel config.
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
// gradRegistry: kernelName -> gradient config.
const gradRegistry = getGlobal('gradRegistry', () => new Map());
/**
 * Looks up the kernel config registered for a kernel/backend pair.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @return The config passed to `registerKernel`, or undefined if none.
 */
function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Looks up the gradient config registered for a kernel.
 *
 * @param kernelName The official TF kernel name.
 * @return The config passed to `registerGradient`, or undefined if none.
 */
function getGradient(kernelName) {
  const gradConfig = gradRegistry.get(kernelName);
  return gradConfig;
}
/**
 * Returns every kernel config registered for `backendName`, in registry
 * (insertion) order.
 *
 * Each config is matched on its own `backendName` rather than by re-parsing
 * the registry key. Keys have the form `${backendName}_${kernelName}`, so
 * splitting on the first `_` misattributes kernels of any backend whose name
 * itself contains `_`.
 *
 * @param backendName The official name of the backend.
 */
function getKernelsForBackend(backendName) {
  const result = [];
  for (const config of kernelRegistry.values()) {
    if (config.backendName === backendName) {
      result.push(config);
    }
  }
  return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 * Re-registering the same kernel/backend pair replaces the previous config and
 * logs a warning.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
  const { kernelName: name, backendName: backend } = config;
  const registryKey = makeKey(name, backend);
  const isDuplicate = kernelRegistry.has(registryKey);
  if (isDuplicate) {
    console.warn(`The kernel '${name}' for backend ` +
      `'${backend}' is already registered`);
  }
  kernelRegistry.set(registryKey, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel. An existing entry is
 * replaced; the override is only reported when the `DEBUG` flag is on.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
  const { kernelName } = config;
  // TODO (yassogba) after 3.0 assess whether we need to keep this gated
  // to debug mode.
  const isOverride = gradRegistry.has(kernelName);
  if (isOverride && env().getBool('DEBUG')) {
    console.warn(`Overriding the gradient for '${kernelName}'`);
  }
  gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel function from the registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error if no kernel is registered for the pair.
 */
function unregisterKernel(kernelName, backendName) {
  // Map#delete reports whether an entry existed, which doubles as the check.
  const removed = kernelRegistry.delete(makeKey(kernelName, backendName));
  if (!removed) {
    throw new Error(`The kernel '${kernelName}' for backend ` +
      `'${backendName}' is not registered`);
  }
}
/**
 * Removes the registered gradient from the global registry.
 *
 * @param kernelName The kernel whose gradient should be removed.
 * @throws Error if no gradient is registered for `kernelName`.
 */
function unregisterGradient(kernelName) {
  if (gradRegistry.delete(kernelName)) {
    return;
  }
  throw new Error(`The gradient '${kernelName}' for backend is not registered`);
}
/** Builds the kernelRegistry key `${backendName}_${kernelName}`. */
function makeKey(kernelName, backendName) {
  const separator = '_';
  return `${backendName}${separator}${kernelName}`;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array (or typed array) to shuffle in-place.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
// tslint:disable-next-line:no-any
function shuffle(array) {
  // Walk backwards. Each step swaps the last unshuffled slot with a slot chosen
  // uniformly from the unshuffled prefix [0, remaining).
  for (let remaining = array.length; remaining > 0;) {
    const pick = (Math.random() * remaining) | 0;
    remaining--;
    const held = array[remaining];
    array[remaining] = array[pick];
    array[pick] = held;
  }
}
/**
 * Clamps `x` to the range [min, max]. The upper bound is applied first, so the
 * result is `min` whenever min > max.
 */
function clamp(min, x, max) {
  const cappedAbove = Math.min(x, max);
  return Math.max(min, cappedAbove);
}
/** Returns `val` when it is even, otherwise `val + 1`. */
function nearestLargerEven(val) {
  if (val % 2 !== 0) {
    return val + 1;
  }
  return val;
}
/** Sums the elements of an array-like, accumulating left to right from 0. */
function sum(arr) {
  let total = 0;
  let idx = 0;
  while (idx < arr.length) {
    total += arr[idx++];
  }
  return total;
}
/**
* Returns a sample from a uniform [a, b) distribution.
*
* @param a The minimum support (inclusive).
* @param b The maximum support (exclusive).
* @return A pseudorandom number on the half-open interval [a,b).
*/
function randUniform(a, b) {
const r = Math.random();
return (b * r) + (1 - r) * a;
}
/**
 * Returns the squared Euclidean distance between two vectors. Elements are
 * converted with `Number()`, and `a.length` sets how many are compared.
 */
function distSquared(a, b) {
  let acc = 0;
  for (let k = 0; k < a.length; k += 1) {
    const delta = Number(a[k]) - Number(b[k]);
    acc += delta * delta;
  }
  return acc;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg Either the message string or a function that returns it. A
 *     function is only invoked on failure, which avoids building the message
 *     on the success path.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function assert(expr, msg) {
  if (expr) {
    return;
  }
  const message = typeof msg === 'string' ? msg : msg();
  throw new Error(message);
}
/**
 * Throws if `shapeA` and `shapeB` are not element-wise equal.
 *
 * @param errorMessagePrefix Text prepended to the mismatch message.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
  const sameShape = arraysEqual(shapeA, shapeB);
  const describeMismatch = () =>
    errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
  assert(sameShape, describeMismatch);
}
/** Throws if `a` is null or undefined (message refers to the tensor constructor). */
function assertNonNull(a) {
  const isPresent = a != null;
  assert(isPresent, () => `The input to the tensor constructor must be a non-null value.`);
}
// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
 * Flattens an arbitrarily nested array, depth-first and left to right.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * @param arr The nested array to flatten. A non-array value is appended as a
 *     single element.
 * @param result The destination array which holds the elements. It is
 *     appended to and returned; passing null behaves like omitting it.
 * @param skipTypedArray If true, typed arrays are appended whole instead of
 *     being expanded. Defaults to false.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function flatten(arr, result = [], skipTypedArray = false) {
  const out = result == null ? [] : result;
  const expandable = Array.isArray(arr) || (isTypedArray(arr) && !skipTypedArray);
  if (!expandable) {
    out.push(arr);
    return out;
  }
  for (let i = 0; i < arr.length; ++i) {
    flatten(arr[i], out, skipTypedArray);
  }
  return out;
}
/**
 * Returns the size (number of elements) of the tensor given its shape, i.e.
 * the product of its dimensions. The empty (scalar) shape has size 1.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function sizeFromShape(shape) {
  return shape.length === 0 ?
    1 :
    shape.reduce((product, dim) => product * dim);
}
/** Returns true for the rank-0 (scalar) shape `[]`. */
function isScalarShape(shape) {
  const rank = shape.length;
  return rank === 0;
}
/**
 * Shallow element-wise (`===`) equality of two array-likes. Identical
 * references are equal. Otherwise a null/undefined argument or a length
 * mismatch yields false.
 */
function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  if (n1 == null || n2 == null || n1.length !== n2.length) {
    return false;
  }
  // Advance while elements agree; equal iff we reached the end.
  let i = 0;
  while (i < n1.length && n1[i] === n2[i]) {
    i++;
  }
  return i === n1.length;
}
/** Returns true when `a` has no fractional part (`a % 1 === 0`). */
function isInt(a) {
  const fraction = a % 1;
  return fraction === 0;
}
/**
 * Hyperbolic tangent. Uses `Math.tanh` when the engine provides it.
 *
 * The fallback evaluates tanh(|x|) = (1 - e^(-2|x|)) / (1 + e^(-2|x|)) and
 * restores the sign. The exponent is never positive, so the exponential cannot
 * overflow. The previous form, (e^(2x) - 1) / (e^(2x) + 1), evaluated to
 * Infinity / Infinity = NaN for large finite x. It also covers +/-Infinity
 * (-> +/-1), and NaN propagates.
 */
function tanh(x) {
  // tslint:disable-next-line:no-any
  if (Math.tanh != null) {
    // tslint:disable-next-line:no-any
    return Math.tanh(x);
  }
  const e = Math.exp(-2 * Math.abs(x));
  const magnitude = (1 - e) / (1 + e);
  return x < 0 ? -magnitude : magnitude;
}
/**
 * Returns a near-square `[width, height]` able to hold `size` elements, where
 * width = ceil(sqrt(size)) and height = ceil(size / width).
 */
function sizeToSquarishShape(size) {
  const width = Math.ceil(Math.sqrt(size));
  const height = Math.ceil(size / width);
  return [width, height];
}
/**
 * Creates a `Uint32Array` containing the indices 0..n-1 in random order.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param n Quantity of shuffled indices to create.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function createShuffledIndices(n) {
  const indices = new Uint32Array(n);
  indices.forEach((_, i) => {
    indices[i] = i;
  });
  shuffle(indices);
  return indices;
}
/**
 * Right-pads the string `a` with spaces up to length `size`. A string already
 * at least `size` long is returned unchanged.
 */
function rightPad(a, size) {
  return size <= a.length ? a : a + ' '.repeat(size - a.length);
}
/**
 * Polls `checkFn` until it returns a truthy value.
 *
 * `checkFn` runs immediately. After each failed attempt, the next attempt is
 * scheduled `delayFn(tryCount)` milliseconds later, where `tryCount` is the
 * number of failed attempts so far (1, 2, ...).
 *
 * @param checkFn Predicate polled until it returns a truthy value.
 * @param delayFn Maps the failed-attempt count to the delay in ms before the
 *     next attempt. Defaults to 0.
 * @param maxCounter Optional maximum number of attempts.
 * @return A promise that resolves once `checkFn` succeeds. It rejects with an
 *     Error once `maxCounter` attempts have failed, or with whatever
 *     `checkFn`/`delayFn` throws.
 */
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
  return new Promise((resolve, reject) => {
    let tryCount = 0;
    const tryFn = () => {
      // Retries run from a setTimeout callback, where an exception would
      // escape as an uncaught error and leave the promise pending forever.
      // Route it to reject() instead.
      try {
        if (checkFn()) {
          resolve();
          return;
        }
        tryCount++;
        const nextBackoff = delayFn(tryCount);
        if (maxCounter != null && tryCount >= maxCounter) {
          reject(new Error(
            `repeatedTry: check did not succeed after ${tryCount} attempt(s).`));
          return;
        }
        setTimeout(tryFn, nextBackoff);
      } catch (err) {
        reject(err);
      }
    };
    tryFn();
  });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size, or
 *     `shape` itself when it has no -1.
 */
function inferFromImplicitShape(shape, size) {
  let knownProduct = 1;
  let implicitAxis = -1;
  shape.forEach((dim, axis) => {
    if (dim >= 0) {
      knownProduct *= dim;
      return;
    }
    if (dim === -1) {
      if (implicitAxis !== -1) {
        throw new Error(`Shapes can only have 1 implicit size. ` +
          `Found -1 at dim ${implicitAxis} and dim ${axis}`);
      }
      implicitAxis = axis;
      return;
    }
    if (dim < 0) {
      throw new Error(`Shapes can not be < 0. Found ${dim} at dim ${axis}`);
    }
  });
  if (implicitAxis === -1) {
    // Fully specified shape: only validate against a positive size.
    if (size > 0 && size !== knownProduct) {
      throw new Error(`Size(${size}) must match the product of shape ${shape}`);
    }
    return shape;
  }
  if (knownProduct === 0) {
    throw new Error(`Cannot infer the missing size in [${shape}] when ` +
      `there are 0 elements`);
  }
  if (size % knownProduct !== 0) {
    throw new Error(`The implicit shape can't be a fractional number. ` +
      `Got ${size} / ${knownProduct}`);
  }
  const inferred = shape.slice();
  inferred[implicitAxis] = size / knownProduct;
  return inferred;
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined for all axes
 *     of `shape`. Negative values count from the end.
 * @param shape The shape the axes refer to.
 * @throws Error if an axis is outside [-rank, rank) or is not an integer.
 */
function parseAxisParam(axis, shape) {
  const rank = shape.length;
  const axes = axis == null ? shape.map((_, i) => i) : [].concat(axis);
  const allInRange = axes.every(ax => ax >= -rank && ax < rank);
  assert(allInRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
    `got axis ${axes}`);
  const allIntegers = axes.every(ax => isInt(ax));
  assert(allIntegers, () => `All values in axis param must be integers but ` +
    `got axis ${axes}`);
  // Map negative axes onto their positive equivalents.
  return axes.map(ax => (ax < 0 ? rank + ax : ax));
}
/**
 * Reduces the shape by removing dimensions of size 1.
 *
 * @param shape The input shape.
 * @param axis Optional axis or axes to squeeze (negative values count from the
 *     end). When omitted or `[]`, every size-1 dimension is removed. Otherwise
 *     only the listed dimensions are removed, and each must have size 1.
 * @return `newShape`, the squeezed shape, and `keptDims`, the indices of the
 *     input dimensions that were kept.
 * @throws Error if a listed axis does not have size 1.
 */
function squeezeShape(shape, axis) {
  const newShape = [];
  const keptDims = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // Sort numerically. The default sort() compares as strings, which would put
  // axis 10 before axis 2 and break the in-order walk over `axes` below.
  const axes = (axis == null || isEmptyArray) ?
    null :
    parseAxisParam(axis, shape).sort((a, b) => a - b);
  let j = 0;
  for (let i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      // Size-1 dims that were not requested for squeezing are kept.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      if (axes[j] <= i) {
        j++;
      }
    }
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return { newShape, keptDims };
}
/**
 * Allocates a typed array of `size` elements for a numeric dtype. A
 * null/undefined dtype means float32.
 *
 * @throws Error for any other dtype (including 'string' and 'complex64').
 */
function getTypedArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Allocates storage of `size` elements for a dtype. Numeric dtypes get a typed
 * array, and 'string' gets a plain Array. A null/undefined dtype means float32.
 *
 * @throws Error for any other dtype.
 */
function getArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    case 'string':
      return new Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Throws on the first element that is NaN or +/-Infinity, tested with the
 * coercing global `isNaN`/`isFinite`. `dtype` is only used in the message.
 */
function checkConversionForErrors(vals, dtype) {
  for (let idx = 0; idx < vals.length; idx++) {
    const value = vals[idx];
    const isInvalid = isNaN(value) || !isFinite(value);
    if (isInvalid) {
      throw Error(`A tensor of type ${dtype} being uploaded contains ${value}.`);
    }
  }
}
/** Returns true if `dtype` is one of the supported dtype names. */
function isValidDtype(dtype) {
  switch (dtype) {
    case 'bool':
    case 'complex64':
    case 'float32':
    case 'int32':
    case 'string':
      return true;
    default:
      return false;
  }
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 *
 * complex64 holds everything. float32 holds all but complex64. int32 loses
 * float32 and complex64 values. bool only holds bool. Any other target type
 * (e.g. 'string') is reported as lossy.
 */
function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      return false;
    case 'float32':
      return oldType === 'complex64';
    case 'int32':
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      return oldType !== 'bool';
    default:
      return true;
  }
}
/** Returns true for the typed-array kinds used as tensor storage (Float32/Int32/Uint8). */
function isTypedArray(a) {
  const storageTypes = [Float32Array, Int32Array, Uint8Array];
  return storageTypes.some(ArrayType => a instanceof ArrayType);
}
/**
 * Returns the storage size in bytes of one element of a numeric dtype.
 *
 * @throws Error for other dtypes (e.g. 'string').
 */
function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
/**
 * Returns the sum of the `length` of every entry in `arr`, or 0 for
 * null/undefined.
 *
 * For encoded entries (e.g. `Uint8Array`s) this is the exact byte count. For
 * JS strings it is the number of UTF-16 code units, which approximates rather
 * than measures the bytes used. Note that it is 1 per unit, not 2 per
 * character as an earlier comment claimed.
 */
function bytesFromStringArray(arr) {
  if (arr == null) {
    return 0;
  }
  return arr.reduce((bytes, entry) => bytes + entry.length, 0);
}
/** Returns true for string primitives and `String` wrapper objects. */
function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  return value instanceof String;
}
/** Returns true only for boolean primitives (not `Boolean` wrapper objects). */
function isBoolean(value) {
  const kind = typeof value;
  return kind === 'boolean';
}
/** Returns true for number primitives (including NaN and Infinity). */
function isNumber(value) {
  const kind = typeof value;
  return kind === 'number';
}
/**
 * Infers a tensor dtype from a value, typed array, or (nested) JS array.
 * Arrays are classified by their first element. Int32Array and Uint8Array map
 * to 'int32', and anything unrecognized defaults to 'float32'.
 */
function inferDtype(values) {
  if (Array.isArray(values)) {
    return inferDtype(values[0]);
  }
  if (values instanceof Float32Array) {
    return 'float32';
  }
  if (values instanceof Int32Array || values instanceof Uint8Array) {
    return 'int32';
  }
  if (isNumber(values)) {
    return 'float32';
  }
  if (isString(values)) {
    return 'string';
  }
  if (isBoolean(values)) {
    return 'bool';
  }
  return 'float32';
}
/** Duck-typed function check: truthy and exposes `constructor`, `call` and `apply`. */
function isFunction(f) {
  if (!f) {
    return false;
  }
  return Boolean(f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` that is >= `start` and < `size`,
 * falling back to `size` itself when there is none.
 */
function nearestDivisor(size, start) {
  let candidate = start;
  while (candidate < size) {
    if (size % candidate === 0) {
      return candidate;
    }
    candidate++;
  }
  return size;
}
/**
 * Returns row-major strides for `shape`, omitting the innermost dimension's
 * implicit stride of 1. The result has rank-1 entries (empty for rank < 2).
 */
function computeStrides(shape) {
  const rank = shape.length;
  if (rank < 2) {
    return [];
  }
  const strides = new Array(rank - 1);
  let stride = shape[rank - 1];
  strides[rank - 2] = stride;
  for (let axis = rank - 3; axis >= 0; axis--) {
    stride *= shape[axis + 1];
    strides[axis] = stride;
  }
  return strides;
}
/**
 * Converts (possibly nested) numeric data into the typed array used to store
 * `dtype`. If `a` already has the matching typed-array type it is returned
 * as-is. For 'bool', any value that does not round to 0 becomes 1.
 *
 * @throws Error for 'string' or an unknown dtype; also (in DEBUG mode) when
 *     the data contains NaN/Infinity.
 */
function toTypedArray(a, dtype) {
  if (dtype === 'string') {
    throw new Error('Cannot convert a string[] to a TypedArray');
  }
  // Nested JS arrays are flattened first; other inputs are used as given.
  const values = Array.isArray(a) ? flatten(a) : a;
  if (env().getBool('DEBUG')) {
    checkConversionForErrors(values, dtype);
  }
  if (noConversionNeeded(values, dtype)) {
    return values;
  }
  switch (dtype) {
    case undefined:
    case null:
    case 'float32':
    case 'complex64':
      return new Float32Array(values);
    case 'int32':
      return new Int32Array(values);
    case 'bool': {
      const bool = new Uint8Array(values.length);
      for (let i = 0; i < bool.length; ++i) {
        if (Math.round(values[i]) !== 0) {
          bool[i] = 1;
        }
      }
      return bool;
    }
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Recursively builds a nested JS array with the given (non-empty) `shape`,
 * reading elements of the flat array-like `a` in row-major order starting at
 * index `offset`.
 */
function createNestedArray(offset, shape, a) {
  const outerDim = shape[0];
  const nested = [];
  if (shape.length === 1) {
    for (let i = 0; i < outerDim; i++) {
      nested.push(a[offset + i]);
    }
    return nested;
  }
  const innerShape = shape.slice(1);
  const innerSize = innerShape.reduce((acc, c) => acc * c);
  for (let i = 0; i < outerDim; i++) {
    nested.push(createNestedArray(offset + i * innerSize, innerShape, a));
  }
  return nested;
}
/**
 * Provides a nested array view of the flat data `a` in the given shape.
 * A rank-0 shape returns the single element `a[0]`, and a shape with zero
 * elements returns `[]`.
 *
 * @throws Error if the shape's size does not equal `a.length`.
 */
function toNestedArray(shape, a) {
  if (shape.length === 0) {
    return a[0];
  }
  const size = shape.reduce((acc, c) => acc * c);
  if (size === 0) {
    return [];
  }
  if (size !== a.length) {
    throw new Error(`[${shape}] does not match the input size ${a.length}.`);
  }
  return createNestedArray(0, shape, a);
}
/** Returns true when `a` already is the typed-array type that stores `dtype`. */
function noConversionNeeded(a, dtype) {
  switch (dtype) {
    case 'float32':
      return a instanceof Float32Array;
    case 'int32':
      return a instanceof Int32Array;
    case 'bool':
      return a instanceof Uint8Array;
    default:
      return false;
  }
}
/** Returns a typed array of `size` ones for `dtype` (see makeZerosTypedArray). */
function makeOnesTypedArray(size, dtype) {
  return makeZerosTypedArray(size, dtype).fill(1);
}
/**
 * Returns a zero-filled typed array of `size` elements for `dtype`. A
 * null/undefined dtype, 'float32' and 'complex64' all use Float32Array.
 *
 * @throws Error for any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
  switch (dtype) {
    case undefined:
    case null:
    case 'float32':
    case 'complex64':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Make nested `TypedArray` filled with zeros.
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element; null/undefined means float32.
 * @throws Error for dtypes other than float32, int32 and bool.
 */
function makeZerosNestedTypedArray(shape, dtype) {
  const size = shape.reduce((prev, curr) => prev * curr, 1);
  let flat;
  if (dtype == null || dtype === 'float32') {
    flat = new Float32Array(size);
  } else if (dtype === 'int32') {
    flat = new Int32Array(size);
  } else if (dtype === 'bool') {
    flat = new Uint8Array(size);
  } else {
    throw new Error(`Unknown data type ${dtype}`);
  }
  return toNestedArray(shape, flat);
}
/**
 * Returns the current high-resolution time in milliseconds relative to an
 * arbitrary time in the past, as reported by the active platform's `now()`.
 *
 * ```js
 * console.log(tf.util.now());
 * ```
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function now() {
  const platform = env().platform;
  return platform.now();
}
/** Throws unless every dimension of `shape` is an integer >= 0. */
function assertNonNegativeIntegerDimensions(shape) {
  shape.forEach((dimSize) => {
    const isValidDim = Number.isInteger(dimSize) && dimSize >= 0;
    assert(isValidDim, () => `Tensor must have a shape comprised of positive integers but got ` +
      `shape [${shape}].`);
  });
}
/**
 * Returns a platform-specific implementation of
 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API),
 * by delegating to the active platform's `fetch`.
 *
 * If `fetch` is defined on the global object (`window`, `process`, etc.),
 * `tf.util.fetch` returns that function.
 *
 * If not, `tf.util.fetch` returns a platform-specific solution.
 *
 * ```js
 * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
 * // handle response
 * ```
 */
/** @doc {heading: 'Util'} */
function fetch$1(path, requestInits) {
  const platform = env().platform;
  return platform.fetch(path, requestInits);
}
/**
 * Encodes the provided string into bytes using the provided encoding scheme,
 * via the active platform's `encode`.
 *
 * @param s The string to encode.
 * @param encoding The encoding scheme. Defaults to utf-8 (also when a falsy
 *     value such as null or '' is passed).
 */
/** @doc {heading: 'Util'} */
function encodeString(s, encoding = 'utf-8') {
  const scheme = encoding || 'utf-8';
  return env().platform.encode(s, scheme);
}
/**
 * Decodes the provided bytes into a string using the provided encoding scheme,
 * via the active platform's `decode`.
 *
 * @param bytes The bytes to decode.
 * @param encoding The encoding scheme. Defaults to utf-8 (also when a falsy
 *     value such as null or '' is passed).
 */
/** @doc {heading: 'Util'} */
function decodeString(bytes, encoding = 'utf-8') {
  const scheme = encoding || 'utf-8';
  return env().platform.decode(bytes, scheme);
}
/**
 * Computes the flat index for a given location (multidimensional index) in a
 * tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides, as produced by `computeStrides` (the
 *     innermost stride of 1 is implicit).
 */
function locToIndex(locs, rank, strides) {
  switch (rank) {
    case 0:
      return 0;
    case 1:
      return locs[0];
    default: {
      const last = locs.length - 1;
      let flatIndex = locs[last];
      for (let axis = 0; axis < last; ++axis) {
        flatIndex += strides[axis] * locs[axis];
      }
      return flatIndex;
    }
  }
}
/**
 * Computes the location (multidimensional index) in a tensor/multidimensional
 * array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor, as produced by `computeStrides` (the
 *     innermost stride of 1 is implicit).
 */
function indexToLoc(index, rank, strides) {
  switch (rank) {
    case 0:
      return [];
    case 1:
      return [index];
    default: {
      const locs = new Array(rank);
      let remainder = index;
      for (let axis = 0; axis < rank - 1; ++axis) {
        locs[axis] = Math.floor(remainder / strides[axis]);
        remainder -= locs[axis] * strides[axis];
      }
      locs[rank - 1] = remainder;
      return locs;
    }
  }
}
// Namespace object exposing the util helpers above (e.g. as `tf.util`).
// `__proto__: null` gives it no prototype, so only these keys are visible.
var util = {
  __proto__: null,
  shuffle,
  clamp,
  nearestLargerEven,
  sum,
  randUniform,
  distSquared,
  assert,
  assertShapesMatch,
  assertNonNull,
  flatten,
  sizeFromShape,
  isScalarShape,
  arraysEqual,
  isInt,
  tanh,
  sizeToSquarishShape,
  createShuffledIndices,
  rightPad,
  repeatedTry,
  inferFromImplicitShape,
  parseAxisParam,
  squeezeShape,
  getTypedArrayFromDType,
  getArrayFromDType,
  checkConversionForErrors,
  isValidDtype,
  hasEncodingLoss,
  isTypedArray,
  bytesPerElement,
  bytesFromStringArray,
  isString,
  isBoolean,
  isNumber,
  inferDtype,
  isFunction,
  nearestDivisor,
  computeStrides,
  toTypedArray,
  toNestedArray,
  makeOnesTypedArray,
  makeZerosTypedArray,
  makeZerosNestedTypedArray,
  now,
  assertNonNegativeIntegerDimensions,
  fetch: fetch$1,
  encodeString,
  decodeString,
  locToIndex,
  indexToLoc
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Times kernel executions through a backend timer and reports each output
 * tensor to a logger (a new `Logger` when none is supplied).
 */
class Profiler {
  constructor(backendTimer, logger) {
    this.backendTimer = backendTimer;
    this.logger = logger == null ? new Logger() : logger;
  }
  /**
   * Runs `f` (expected to return an array of output tensors) under the backend
   * timer and returns its outputs synchronously. For each output, once its
   * data and the timing are available, the values are checked for NaN/Infinity
   * and a profile line is logged.
   */
  profileKernel(kernelName, inputs, f) {
    let outputs;
    const timer = this.backendTimer.time(() => {
      outputs = f();
    });
    const reportOutput = (tensor) => {
      // Dangling promise here because we don't want to propagate up
      // asynchronicity.
      tensor.data().then((vals) => {
        checkComputationForErrors(vals, tensor.dtype, kernelName);
        timer.then((timing) => {
          const extraInfo = timing.getExtraProfileInfo != null ?
            timing.getExtraProfileInfo() :
            '';
          this.logger.logKernelProfile(kernelName, tensor, vals, timing.kernelMs, inputs, extraInfo);
        });
      });
    };
    outputs.forEach(reportOutput);
    return outputs;
  }
}
/**
 * Scans a kernel's output values for NaN or +/-Infinity. On the first
 * non-finite value it logs a console warning.
 *
 * Only 'float32' results are scanned; any other dtype returns false at once.
 *
 * @param vals The kernel's output values.
 * @param dtype The dtype of the output tensor.
 * @param kernelName Kernel name, used in the warning message.
 * @returns true if a non-finite value was found (and a warning was logged),
 *     false otherwise.
 */
function checkComputationForErrors(vals, dtype, kernelName) {
  if (dtype !== 'float32') {
    // Only floating point computations will generate NaN values.
    return false;
  }
  for (let i = 0; i < vals.length; i++) {
    const num = vals[i];
    // isFinite() is false for NaN as well as +/-Infinity, so this single check
    // covers every non-finite value.
    if (!isFinite(num)) {
      // This warns instead of throwing; the boolean result is what makes the
      // behavior testable.
      console.warn(`Found ${num} in the result of '${kernelName}'`);
      return true;
    }
  }
  return false;
}
/**
 * Default profile sink: prints one styled console line per kernel execution.
 */
class Logger {
  /**
   * Logs name, time, output rank/shape/size, input shapes and extra info as a
   * single `console.log` call with `%c` CSS styling.
   *
   * @param name Kernel name.
   * @param result Output tensor (its `rank`, `size` and `shape` are read).
   * @param vals Output values (unused here; kept for the logger interface).
   * @param timeMs Either a number of milliseconds or an object whose `error`
   *     field is printed instead.
   * @param inputs Map of input name to input. Entries without a `shape` (e.g.
   *     an HTMLImageElement) or null entries report the output shape instead.
   * @param extraInfo Extra backend-specific profile text.
   */
  logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
    const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
      timeMs['error'];
    const paddedName = rightPad(name, 25);
    const rank = result.rank;
    const size = result.size;
    const shape = rightPad(result.shape.toString(), 14);
    let inputShapesDescription = '';
    // Named `inputName` so it does not shadow the kernel `name` parameter.
    for (const inputName in inputs) {
      const input = inputs[inputName];
      // The input might be a non-tensor (e.g. HTMLImageElement) or missing, in
      // which case we claim the output shape as input shape.
      const inputShape = (input != null && input.shape) || result.shape;
      const inputRank = inputShape.length;
      inputShapesDescription +=
        `${inputName}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
    }
    console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
  }
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * Each returned node is a shallow copy of the original. Its `inputs` are
 * pruned to the entries that are a function of `xs`. The original tape nodes
 * are not mutated.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns The filtered, pruned tape nodes in original tape order.
 */
function getFilteredNodesXToY(tape, xs, y) {
  // Forward pass to compute all the nodes and Tensors that are transitively a
  // function of x.
  const tensorsFromX = {};
  const nodesFromX = {};
  for (let i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    // A single input derived from x is enough to mark the node and all of its
    // outputs. The membership test is a direct lookup; no need to loop over xs.
    for (const inputName in nodeInputs) {
      if (tensorsFromX[nodeInputs[inputName].id]) {
        node.outputs.forEach(output => tensorsFromX[output.id] = true);
        nodesFromX[node.id] = true;
        break;
      }
    }
  }
  // Backward pass to find all of the nodes and Tensors that lead to y.
  const tensorsLeadToY = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY = {};
  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    // If any of the outputs lead to y, mark all of the inputs as leading to y.
    const anyOutputLeadsToY = node.outputs.some(output => tensorsLeadToY[output.id]);
    if (anyOutputLeadsToY) {
      for (const inputName in nodeInputs) {
        tensorsLeadToY[nodeInputs[inputName].id] = true;
        nodesToY[node.id] = true;
      }
    }
  }
  // Return the paths that come from x and lead to y.
  const filteredTape = [];
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    if (nodesFromX[node.id] && nodesToY[node.id]) {
      // Prune the inputs from the node that aren't a function of x.
      const prunedInputs = {};
      for (const inputName in node.inputs) {
        const nodeInput = node.inputs[inputName];
        if (tensorsFromX[nodeInput.id]) {
          prunedInputs[inputName] = nodeInput;
        }
      }
      // Shallow-copy the node (outputs are shared) and swap in the pruned
      // inputs.
      const prunedNode = Object.assign({}, node);
      prunedNode.inputs = prunedInputs;
      filteredTape.push(prunedNode);
    }
  }
  return filteredTape;
}
/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * The tape is walked from the last node to the first. Each node's gradient
 * function is called with its output gradients, and the per-input gradients
 * it yields are summed into the accumulator map.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor id to its gradient. This
 *     map is mutated by this method.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Wrapper used to run each input-gradient thunk.
 * @throws Error if a node has no gradient function, if a gradient for one of
 *     its inputs is missing, or if a computed gradient is not 'float32' or its
 *     shape differs from the input's shape.
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) {
  for (let idx = filteredTape.length - 1; idx >= 0; idx--) {
    const node = filteredTape[idx];
    // Outputs outside the back-propagation subgraph do not affect the final
    // output, so they contribute a null dy.
    const dys = node.outputs.map(output => {
      const accumulated = tensorAccumulatedGradientMap[output.id];
      return accumulated != null ? accumulated : null;
    });
    if (node.gradient == null) {
      throw new Error(`Cannot compute gradient: gradient function not found ` +
        `for ${node.kernelName}.`);
    }
    const inputGradients = node.gradient(dys);
    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error(`Cannot backprop through input ${inputName}. ` +
          `Available gradients found: ${Object.keys(inputGradients)}.`);
      }
      const dx = tidy(() => inputGradients[inputName]());
      if (dx.dtype !== 'float32') {
        throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
          `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
      }
      const x = node.inputs[inputName];
      if (!arraysEqual(dx.shape, x.shape)) {
        throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
          `'${inputName}' has shape '${dx.shape}', which does not match ` +
          `the shape of the input '${x.shape}'`);
      }
      const previous = tensorAccumulatedGradientMap[x.id];
      if (previous == null) {
        tensorAccumulatedGradientMap[x.id] = dx;
      } else {
        // Sum into a new gradient, then release the superseded one.
        tensorAccumulatedGradientMap[x.id] = previous.add(dx);
        previous.dispose();
      }
    }
  }
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Maximum number of values before we decide to show ellipsis.
// (Compared against the size of a dimension when printing.)
const FORMAT_LIMIT_NUM_VALS = 20;
// Number of first and last values to show when displaying a, b,...,y, z.
const FORMAT_NUM_FIRST_LAST_VALS = 3;
// Number of significant digits to show.
// NOTE(review): this value is passed to Number#toFixed, so it actually
// controls digits after the decimal point, not significant digits.
const FORMAT_NUM_SIG_DIGITS = 7;
/**
 * Renders tensor values as a multi-line, human-readable string.
 *
 * The result starts with a 'Tensor' header line. In verbose mode it is
 * followed by dtype, rank, shape and a 'values:' label, and then the
 * formatted values, each line prefixed by a space.
 */
function tensorToString(vals, shape, dtype, verbose) {
  const strides = computeStrides(shape);
  const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
  const valuesBlock = subTensorToString(vals, shape, dtype, strides, padPerCol)
    .map(line => ' ' + line)
    .join('\n');
  const header = verbose ?
    ['Tensor', ` dtype: ${dtype}`, ` rank: ${shape.length}`, ` shape: [${shape}]`, ` values:`] :
    ['Tensor'];
  return [...header, valuesBlock].join('\n');
}
/**
 * Computes, for each column (index along the last axis), the widest rendered
 * value so printed columns line up. Tensors of rank 0 or 1 get all-zero
 * padding.
 */
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
  const numCols = strides[strides.length - 1];
  const padPerCol = new Array(numCols).fill(0);
  if (shape.length <= 1) {
    return padPerCol;
  }
  const totalSize = sizeFromShape(shape);
  // complex64 stores (real, imag) pairs; group them so each column is one value.
  const entries = dtype === 'complex64' ? createComplexTuples(vals) : vals;
  const numRows = totalSize / numCols;
  for (let r = 0; r < numRows; r++) {
    const rowStart = r * numCols;
    for (let col = 0; col < numCols; col++) {
      const width = valToString(entries[rowStart + col], 0, dtype).length;
      padPerCol[col] = Math.max(padPerCol[col], width);
    }
  }
  return padPerCol;
}
/**
 * Formats a single element for printing and right-pads it to `pad`.
 *
 * - Complex pairs print as "re + imj".
 * - Strings are single-quoted.
 * - 'bool' values print as true/false.
 * - Numbers are rounded via toFixed(FORMAT_NUM_SIG_DIGITS), with trailing
 *   zeros dropped by parseFloat.
 */
function valToString(val, pad, dtype) {
  const round = num => parseFloat(num.toFixed(FORMAT_NUM_SIG_DIGITS));
  let text;
  if (Array.isArray(val)) {
    text = `${round(val[0])} + ` + `${round(val[1])}j`;
  } else if (isString(val)) {
    text = `'${val}'`;
  } else if (dtype === 'bool') {
    text = boolNumToString(val);
  } else {
    text = String(round(val));
  }
  return rightPad(text, pad);
}
/** Maps a numeric bool (0 / non-zero) to the text 'false' / 'true'. */
function boolNumToString(v) {
  if (v === 0) {
    return 'false';
  }
  return 'true';
}
/**
 * Recursively formats the values of a (sub-)tensor into printable lines.
 *
 * Dimensions larger than FORMAT_LIMIT_NUM_VALS are elided: only the first and
 * last FORMAT_NUM_FIRST_LAST_VALS entries are shown, around an ellipsis.
 *
 * @param vals Flat values of this sub-tensor (two storage slots per element
 *     for 'complex64').
 * @param shape Shape of this sub-tensor.
 * @param dtype Element dtype.
 * @param strides Strides matching `shape`.
 * @param padPerCol Column widths used to align rank-1 rows.
 * @param isLast Whether this sub-tensor is the last sibling. Non-last blocks
 *     of rank >= 2 get a trailing separator.
 * @returns The formatted lines (no leading indentation added).
 */
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
  const storagePerElement = dtype === 'complex64' ? 2 : 1;
  const size = shape[0];
  const rank = shape.length;
  if (rank === 0) {
    if (dtype === 'complex64') {
      const complexTuple = createComplexTuples(vals);
      return [valToString(complexTuple[0], 0, dtype)];
    }
    if (dtype === 'bool') {
      return [boolNumToString(vals[0])];
    }
    return [vals[0].toString()];
  }
  if (rank === 1) {
    if (size > FORMAT_LIMIT_NUM_VALS) {
      const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
      let firstVals = Array.from(vals.slice(0, firstValsSize));
      let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
      if (dtype === 'complex64') {
        firstVals = createComplexTuples(firstVals);
        lastVals = createComplexTuples(lastVals);
      }
      return [
        '[' +
        firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
          .join(', ') +
        ', ..., ' +
        lastVals
          .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
          .join(', ') +
        ']'
      ];
    }
    const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
      Array.from(vals);
    return [
      '[' +
      displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
        .join(', ') +
      ']'
    ];
  }
  // The array is rank 2 or more.
  // Separator after a non-last block: ',\n' plus one extra newline for every
  // dimension above 2.
  let newLineSep = ',\n';
  for (let i = 2; i < rank; i++) {
    newLineSep += '\n';
  }
  if (size === 0) {
    // An empty leading dimension has no rows. Without this guard the bracket
    // logic below would read lines[0] of an empty array and render
    // '[undefined,'.
    return ['[]' + (isLast ? '' : newLineSep)];
  }
  const subshape = shape.slice(1);
  const substrides = strides.slice(1);
  const stride = strides[0] * storagePerElement;
  const lines = [];
  if (size > FORMAT_LIMIT_NUM_VALS) {
    for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
      const start = i * stride;
      const end = start + stride;
      lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
    }
    lines.push('...');
    for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
      const start = i * stride;
      const end = start + stride;
      lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
    }
  } else {
    for (let i = 0; i < size; i++) {
      const start = i * stride;
      const end = start + stride;
      lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
    }
  }
  // Rows of a matrix are comma-separated; higher ranks rely on newLineSep
  // appended by the nested calls instead.
  const sep = rank === 2 ? ',' : '';
  lines[0] = '[' + lines[0] + sep;
  for (let i = 1; i < lines.length - 1; i++) {
    lines[i] = ' ' + lines[i] + sep;
  }
  lines[lines.length - 1] =
    ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
  return lines;
}
/**
 * Groups interleaved [re0, im0, re1, im1, ...] storage into [[re0, im0], ...]
 * pairs. With an odd-length input, the final pair's second slot is undefined.
 */
function createComplexTuples(vals) {
  const pairCount = Math.ceil(vals.length / 2);
  return Array.from({ length: pairCount }, (_, k) => [vals[2 * k], vals[2 * k + 1]]);
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * A mutable object, similar to `tf.Tensor`, that allows users to set values
 * at locations before converting to an immutable `tf.Tensor`.
 *
 * See `tf.buffer` for creating a tensor buffer.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class TensorBuffer {
  /**
   * @param shape Dimensions of the buffer (copied).
   * @param dtype Element dtype; 'complex64' is rejected.
   * @param values Optional backing values. Their length must equal the size
   *     implied by `shape`. When omitted, an array is obtained from
   *     `getArrayFromDType(dtype, size)`.
   */
  constructor(shape, dtype, values) {
    this.dtype = dtype;
    this.shape = shape.slice();
    this.size = sizeFromShape(shape);
    if (values != null) {
      const n = values.length;
      assert(n === this.size, () => `Length of values '${n}' does not match the size ` +
        `inferred by the shape '${this.size}'.`);
    }
    if (dtype === 'complex64') {
      throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
        `a TensorBuffer for the real and imaginary parts separately and ` +
        `call tf.complex(real, imag).`);
    }
    this.values = values || getArrayFromDType(dtype, this.size);
    this.strides = computeStrides(shape);
  }
  /**
   * Sets a value in the buffer at a given location.
   *
   * @param value The value to set.
   * @param locs The location indices. Must provide exactly `rank` indices,
   *     except that omitting them addresses the first element (or the single
   *     element of a scalar buffer).
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  set(value, ...locs) {
    if (locs.length === 0) {
      // A rank-0 (scalar) buffer takes zero coordinates. Defaulting to [0]
      // there would always fail the rank check below, making scalar buffers
      // impossible to set.
      locs = this.rank === 0 ? [] : [0];
    }
    assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
      `match the rank (${this.rank})`);
    const index = this.locToIndex(locs);
    this.values[index] = value;
  }
  /**
   * Returns the value in the buffer at the provided location.
   *
   * @param locs The location indices. Omitting them reads the first element.
   * @throws Error if any index is negative or not below its dimension size.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  get(...locs) {
    if (locs.length === 0) {
      locs = [0];
    }
    let i = 0;
    for (const loc of locs) {
      if (loc < 0 || loc >= this.shape[i]) {
        const msg = `Requested out of range element at ${locs}. ` +
          ` Buffer shape=${this.shape}`;
        throw new Error(msg);
      }
      i++;
    }
    let index = locs[locs.length - 1];
    for (let i = 0; i < locs.length - 1; ++i) {
      index += this.strides[i] * locs[i];
    }
    return this.values[index];
  }
  /** Converts a multi-dimensional location into a flat (row-major) index. */
  locToIndex(locs) {
    if (this.rank === 0) {
      return 0;
    }
    else if (this.rank === 1) {
      return locs[0];
    }
    let index = locs[locs.length - 1];
    for (let i = 0; i < locs.length - 1; ++i) {
      index += this.strides[i] * locs[i];
    }
    return index;
  }
  /** Converts a flat (row-major) index back into a multi-dimensional location. */
  indexToLoc(index) {
    if (this.rank === 0) {
      return [];
    }
    else if (this.rank === 1) {
      return [index];
    }
    const locs = new Array(this.shape.length);
    for (let i = 0; i < locs.length - 1; ++i) {
      locs[i] = Math.floor(index / this.strides[i]);
      index -= locs[i] * this.strides[i];
    }
    locs[locs.length - 1] = index;
    return locs;
  }
  get rank() {
    return this.shape.length;
  }
  /**
   * Creates an immutable `tf.Tensor` object from the buffer.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  toTensor() {
    return trackerFn().makeTensor(this.values, this.shape, this.dtype);
  }
}
// Module-level hooks that let Tensor reach the engine and the ops without
// importing them directly. Both start unset (null).
// For tracking tensor creation and disposal.
let trackerFn = null;
// Used by chaining methods to call into ops.
let opHandler = null;
/**
 * Installs the tensor tracker. The Tensor class calls it to notify the tracker
 * of tensor creation, reads and disposal.
 *
 * @param tracker Function returning the tracker object.
 */
function setTensorTracker(tracker) {
  trackerFn = tracker;
}
/**
 * Installs the op handler that backs Tensor's chaining methods (print, clone,
 * cast, buffer, ...).
 *
 * @param newHandler The op handler object.
 */
function setOpHandler(newHandler) {
  opHandler = newHandler;
}
/**
 * A `tf.Tensor` object represents an immutable, multidimensional array of
 * numbers that has a shape and a data type.
 *
 * See `tf.tensor` for details on how to create a `tf.Tensor`.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class Tensor {
  constructor(shape, dtype, dataId, id) {
    /** Whether this tensor has been globally kept. */
    this.kept = false;
    this.isDisposedInternal = false;
    this.shape = shape.slice();
    this.dtype = dtype || 'float32';
    this.size = sizeFromShape(shape);
    this.strides = computeStrides(shape);
    this.dataId = dataId;
    this.id = id;
    this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
  }
  /** Number of dimensions (`shape.length`). */
  get rank() {
    return this.shape.length;
  }
  /**
   * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async buffer() {
    const vals = await this.data();
    return opHandler.buffer(this.shape, this.dtype, vals);
  }
  /** Returns a `tf.TensorBuffer` that holds the underlying data. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  bufferSync() {
    return opHandler.buffer(this.shape, this.dtype, this.dataSync());
  }
  /**
   * Returns the tensor data as a nested array. The transfer of data is done
   * asynchronously.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async array() {
    const vals = await this.data();
    return toNestedArray(this.shape, vals);
  }
  /**
   * Returns the tensor data as a nested array. The transfer of data is done
   * synchronously.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  arraySync() {
    return toNestedArray(this.shape, this.dataSync());
  }
  /**
   * Asynchronously downloads the values from the `tf.Tensor`. Returns a
   * promise of `TypedArray` that resolves when the computation has finished.
   * 'string' tensors are decoded from their stored bytes.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async data() {
    this.throwIfDisposed();
    const data = trackerFn().read(this.dataId);
    if (this.dtype === 'string') {
      const bytes = await data;
      try {
        return bytes.map(b => decodeString(b));
      }
      catch (_a) {
        throw new Error('Failed to decode the string bytes into utf-8. ' +
          'To get the original bytes, call tensor.bytes().');
      }
    }
    return data;
  }
  /**
   * Synchronously downloads the values from the `tf.Tensor`. This blocks the
   * UI thread until the values are ready, which can cause performance issues.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  dataSync() {
    this.throwIfDisposed();
    const data = trackerFn().readSync(this.dataId);
    if (this.dtype === 'string') {
      try {
        return data.map(b => decodeString(b));
      }
      catch (_a) {
        throw new Error('Failed to decode the string bytes into utf-8. ' +
          'To get the original bytes, call tensor.bytes().');
      }
    }
    return data;
  }
  /**
   * Returns the underlying bytes of the tensor's data. For 'string' tensors
   * this is the stored per-element byte arrays. Otherwise it is a Uint8Array
   * over exactly the bytes of the typed array.
   */
  async bytes() {
    this.throwIfDisposed();
    const data = await trackerFn().read(this.dataId);
    if (this.dtype === 'string') {
      return data;
    }
    // The typed array may be a view into a larger ArrayBuffer (non-zero
    // byteOffset or shorter byteLength). Restrict the result to the view's own
    // range instead of exposing the whole underlying buffer.
    return new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
  }
  /**
   * Disposes `tf.Tensor` from memory. Calling it again is a no-op.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  dispose() {
    if (this.isDisposed) {
      return;
    }
    trackerFn().disposeTensor(this);
    this.isDisposedInternal = true;
  }
  get isDisposed() {
    return this.isDisposedInternal;
  }
  throwIfDisposed() {
    if (this.isDisposed) {
      throw new Error(`Tensor is disposed.`);
    }
  }
  /**
   * Prints the `tf.Tensor`. See `tf.print` for details.
   *
   * @param verbose Whether to print verbose information about the tensor,
   *    including dtype and size.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  print(verbose = false) {
    return opHandler.print(this, verbose);
  }
  /** Returns a copy of the tensor. See `tf.clone` for details. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  clone() {
    this.throwIfDisposed();
    return opHandler.clone(this);
  }
  /**
   * Returns a human-readable description of the tensor. Useful for logging.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  toString(verbose = false) {
    const vals = this.dataSync();
    return tensorToString(vals, this.shape, this.dtype, verbose);
  }
  cast(dtype) {
    this.throwIfDisposed();
    return opHandler.cast(this, dtype);
  }
  variable(trainable = true, name, dtype) {
    this.throwIfDisposed();
    return trackerFn().makeVariable(this, trainable, name, dtype);
  }
}
// Duck-typed instanceof: any object with dataId, shape and dtype counts as a
// Tensor (e.g. tensors created by another copy of the library).
Object.defineProperty(Tensor, Symbol.hasInstance, {
  value: (instance) => {
    return !!instance && instance.dataId != null && instance.shape != null &&
      instance.dtype != null;
  }
});
/**
 * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class Variable extends Tensor {
  /**
   * @param initialValue Tensor whose shape, dtype and dataId seed this variable.
   * @param trainable Whether the variable is trainable.
   * @param name Variable name.
   * @param tensorId Id assigned to this tensor.
   */
  constructor(initialValue, trainable, name, tensorId) {
    const { shape, dtype, dataId } = initialValue;
    super(shape, dtype, dataId, tensorId);
    this.trainable = trainable;
    this.name = name;
  }
  /**
   * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
   * the same shape and dtype as the old `tf.Tensor`.
   *
   * The tracker is told to dispose this variable's old data. The variable then
   * adopts `newValue`'s dataId, and the tracker's reference count for it is
   * incremented.
   *
   * @param newValue New tensor to be assigned to this variable.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  assign(newValue) {
    const sameDtype = newValue.dtype === this.dtype;
    if (!sameDtype) {
      throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
        `previous value (${this.dtype}) must match`);
    }
    const sameShape = arraysEqual(newValue.shape, this.shape);
    if (!sameShape) {
      throw new Error(`shape of the new value (${newValue.shape}) and ` +
        `previous value (${this.shape}) must match`);
    }
    trackerFn().disposeTensor(this);
    this.dataId = newValue.dataId;
    trackerFn().incRef(this, null /* backend */);
  }
  /** Disposes the variable through the tracker and marks it disposed. */
  dispose() {
    trackerFn().disposeVariable(this);
    this.isDisposedInternal = true;
  }
}
// Duck-typed instanceof: a Tensor that also has an `assign` function counts as
// a Variable.
Object.defineProperty(Variable, Symbol.hasInstance, {
  value: (instance) => instance instanceof Tensor &&
    instance.assign != null &&
    instance.assign instanceof Function,
});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** String-valued enum of tensor ranks R0 through R6; each value equals its key. */
var Rank = {};
['R0', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6'].forEach(rankName => {
  Rank[rankName] = rankName;
});
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs. Each map gives the dtype that results from combining the map's own
// dtype with the dtype used as the key.
var UpcastInt32AndMap = {
  float32: 'float32',
  int32: 'int32',
  bool: 'int32',
  complex64: 'complex64',
};
var UpcastBoolAndMap = {
  float32: 'float32',
  int32: 'int32',
  bool: 'bool',
  complex64: 'complex64',
};
var UpcastFloat32AndMap = {
  float32: 'float32',
  int32: 'float32',
  bool: 'float32',
  complex64: 'complex64',
};
var UpcastComplex64AndMap = {
  float32: 'complex64',
  int32: 'complex64',
  bool: 'complex64',
  complex64: 'complex64',
};
// Two-level lookup table: upcastTypeMap[typeA][typeB] -> result dtype.
const upcastTypeMap = {
  'float32': UpcastFloat32AndMap,
  'int32': UpcastInt32AndMap,
  'bool': UpcastBoolAndMap,
  'complex64': UpcastComplex64AndMap
};
/**
 * Returns the dtype that results from combining `typeA` with `typeB`.
 *
 * 'string' only combines with 'string'. Every other pair is looked up in
 * `upcastTypeMap`.
 *
 * @throws Error if 'string' is mixed with another dtype, or if either dtype is
 *     not in the upcast table. Previously an unknown typeA produced a cryptic
 *     TypeError and an unknown typeB silently returned undefined.
 */
function upcastType(typeA, typeB) {
  if (typeA === 'string' || typeB === 'string') {
    if (typeA === 'string' && typeB === 'string') {
      return 'string';
    }
    throw new Error(`Can not upcast ${typeA} with ${typeB}`);
  }
  // Own-property checks keep names like 'toString' from resolving to
  // Object.prototype members.
  const hasOwn = Object.prototype.hasOwnProperty;
  const row = hasOwn.call(upcastTypeMap, typeA) ? upcastTypeMap[typeA] : null;
  if (row == null || !hasOwn.call(row, typeB)) {
    throw new Error(`Can not upcast ${typeA} with ${typeB}`);
  }
  return row[typeB];
}
/**
 * Returns the output type after summation: the input dtype upcast against
 * 'int32'.
 */
function sumOutType(type) {
  const accumulatorType = 'int32';
  return upcastType(type, accumulatorType);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns the two tensors unchanged when their dtypes already agree. Otherwise
 * returns both cast to their common upcast dtype (a first, then b).
 */
function makeTypesMatch(a, b) {
  if (a.dtype !== b.dtype) {
    const commonDtype = upcastType(a.dtype, b.dtype);
    return [a.cast(commonDtype), b.cast(commonDtype)];
  }
  return [a, b];
}
/** Asserts (via `assert`) that both tensors share the same dtype. */
function assertTypesMatch(a, b) {
  const dtypesMatch = a.dtype === b.dtype;
  assert(dtypesMatch, () => `The dtypes of the first(${a.dtype}) and second(${b.dtype}) input must match`);
}
/** True when some tensor in `tensorList` has the same `id` as `tensor`. */
function isTensorInList(tensor, tensorList) {
  const targetId = tensor.id;
  return tensorList.some(candidate => candidate.id === targetId);
}
/**
 * Extracts any `Tensor`s found within the provided object.
 *
 * @param container an object that may be a `Tensor` or may directly contain
 *     `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
 *     is safe to pass any object here, except that `Promise`s are not
 *     supported.
 * @returns An array of `Tensors` found within the passed object. If the
 *     argument is simply a `Tensor`, a list containing that `Tensor` is
 *     returned. If the object is not a `Tensor` or does not contain
 *     `Tensors`, an empty list is returned.
 */
function getTensorsInContainer(result) {
  const found = [];
  // A fresh visited-set per call guards against revisiting shared children.
  walkTensorContainer(result, found, new Set());
  return found;
}
/**
 * Depth-first walk that appends every Tensor reachable from `container` to
 * `list`. Values already recorded in `seen` are skipped, so shared or cyclic
 * references are visited once.
 */
function walkTensorContainer(container, list, seen) {
  if (container == null) {
    return;
  }
  if (container instanceof Tensor) {
    list.push(container);
    return;
  }
  if (!isIterable(container)) {
    return;
  }
  // for...in walks array indices as well as object keys.
  for (const key in container) {
    const child = container[key];
    if (seen.has(child)) {
      continue;
    }
    seen.add(child);
    walkTensorContainer(child, list, seen);
  }
}
// True for arrays and for any value whose typeof is 'object' (note: this
// includes null).
// tslint:disable-next-line:no-any
function isIterable(obj) {
  return typeof obj === 'object' || Array.isArray(obj);
}
// Namespace object (null prototype) exposing the tensor utility helpers.
var tensor_util = Object.assign(Object.create(null), {
  makeTypesMatch,
  assertTypesMatch,
  isTensorInList,
  getTensorsInContainer,
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Mutable bookkeeping for the engine: counters, scopes, variables, profiling. */
class EngineState {
  constructor() {
    Object.assign(this, {
      // Public since optimizers will use it.
      registeredVariables: {},
      nextTapeNodeId: 0,
      numBytes: 0,
      numTensors: 0,
      numStringTensors: 0,
      numDataBuffers: 0,
      // Depth of nested tf.grad() calls (1 = first-order, 2 = second-order,
      // ...); used to decide whether the tape can be dropped after backprop.
      gradientDepth: 0,
      // Depth of nested kernel calls; the tape is off when this exceeds 1.
      kernelDepth: 0,
      scopeStack: [],
      // Per-kernel count of data moves, kept as a stack because kernels can
      // call other kernels recursively.
      numDataMovesStack: [],
      nextScopeId: 0,
      tensorInfo: new WeakMap(),
      profiling: false,
      activeProfile: { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null },
    });
  }
  /** Calls `dispose()` on every registered variable. */
  dispose() {
    for (const name in this.registeredVariables) {
      const variable = this.registeredVariables[name];
      variable.dispose();
    }
  }
}
class Engine {
// Creates an engine bound to the given environment, with no backends
// registered or initialized yet.
constructor(ENV) {
this.ENV = ENV;
// Backend name -> initialized backend instance (filled by initializeBackend).
this.registry = {};
// Backend name -> { factory, priority } (filled by registerBackend).
this.registryFactory = {};
// Incremented for each async backend initialization; initializeBackend
// compares against it to discard results of superseded initializations.
this.pendingBackendInitId = 0;
this.state = new EngineState();
}
async ready() {
if (this.pendingBackendInit != null) {
return this.pendingBackendInit.then(() => { });
}
if (this.backendInstance != null) {
return;
}
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const success = await this.initializeBackend(backendName).success;
if (success) {
await this.setBackend(backendName);
return;
}
}
throw new Error(`Could not initialize any backends, all backend initializations ` +
`failed.`);
}
get backend() {
if (this.pendingBackendInit != null) {
throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
`sure to await tf.ready() or await tf.setBackend() before calling ` +
`other methods`);
}
if (this.backendInstance == null) {
const { name, asyncInit } = this.initializeBackendsAndReturnBest();
if (asyncInit) {
throw new Error(`The highest priority backend '${name}' has not yet been ` +
`initialized. Make sure to await tf.ready() or ` +
`await tf.setBackend() before calling other methods`);
}
this.setBackend(name);
}
return this.backendInstance;
}
backendNames() {
return Object.keys(this.registryFactory);
}
findBackend(backendName) {
if (!(backendName in this.registry)) {
// If the backend hasn't been initialized but we have a registry entry for
// it, initialize it and return it.
if (backendName in this.registryFactory) {
const { asyncInit } = this.initializeBackend(backendName);
if (asyncInit) {
// Backend is not ready yet.
return null;
}
}
else {
return null;
}
}
return this.registry[backendName];
}
findBackendFactory(backendName) {
if (!(backendName in this.registryFactory)) {
return null;
}
return this.registryFactory[backendName].factory;
}
registerBackend(backendName, factory, priority = 1) {
if (backendName in this.registryFactory) {
console.warn(`${backendName} backend was already registered. ` +
`Reusing existing backend factory.`);
return false;
}
this.registryFactory[backendName] = { factory, priority };
return true;
}
async setBackend(backendName) {
if (this.registryFactory[backendName] == null) {
throw new Error(`Backend name '${backendName}' not found in registry`);
}
this.backendName = backendName;
if (this.registry[backendName] == null) {
this.backendInstance = null;
const { success, asyncInit } = this.initializeBackend(backendName);
const result = asyncInit ? await success : success;
if (!result) {
return false;
}
}
this.backendInstance = this.registry[backendName];
this.setupRegisteredKernels();
// Reset the profiler.
this.profiler = new Profiler(this.backendInstance);
return true;
}
setupRegisteredKernels() {
const kernels = getKernelsForBackend(this.backendName);
kernels.forEach(kernel => {
if (kernel.setupFunc != null) {
kernel.setupFunc(this.backendInstance);
}
});
}
disposeRegisteredKernels(backendName) {
const kernels = getKernelsForBackend(backendName);
kernels.forEach(kernel => {
if (kernel.disposeFunc != null) {
kernel.disposeFunc(this.registry[backendName]);
}
});
}
/**
* Initializes a backend by looking up the backend name in the factory
* registry and calling the factory method.
*
* Returns an object `{success, asyncInit}`:
* - Synchronous factory: `success` is a boolean and `asyncInit` is false.
* - Factory returning a native Promise: `success` is a promise of a boolean,
*   `asyncInit` is true, and `pendingBackendInit` holds that promise until it
*   settles.
* A factory that throws synchronously is logged via console.warn and reported
* as `{success: false, asyncInit: false}`.
*
* Throws an error if there is no backend in the factory registry.
*/
initializeBackend(backendName) {
const registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
}
try {
const backend = registryFactoryEntry.factory();
// Test if the factory returns a promise.
// (Promise.resolve returns its argument unchanged only for a native Promise,
// so other thenables take the synchronous branch.)
if (Promise.resolve(backend) === backend) {
// Tag this initialization. A later async init or removeBackend bumps the
// counter, which makes this one outdated.
const promiseId = ++this.pendingBackendInitId;
const success = backend
.then(backendInstance => {
// Outdated promise. Another backend was set in the meantime.
// NOTE(review): the superseded instance is neither registered nor
// disposed here.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.registry[backendName] = backendInstance;
this.pendingBackendInit = null;
return true;
})
.catch(err => {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.pendingBackendInit = null;
console.warn(`Initialization of backend ${backendName} failed`);
console.warn(err.stack || err.message);
return false;
});
this.pendingBackendInit = success;
return { success, asyncInit: true };
}
else {
this.registry[backendName] = backend;
return { success: true, asyncInit: false };
}
}
catch (err) {
// Synchronous factory failure: log and report, do not rethrow.
console.warn(`Initialization of backend ${backendName} failed`);
console.warn(err.stack || err.message);
return { success: false, asyncInit: false };
}
}
removeBackend(backendName) {
if (!(backendName in this.registryFactory)) {
throw new Error(`${backendName} backend not found in registry`);
}
if (this.backendName === backendName && this.pendingBackendInit != null) {
// There is a pending promise of the backend we want to remove. Make it
// obsolete.
this.pendingBackendInitId++;
}
if (backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
delete this.registryFactory[backendName];
// Unset the backend if it is active.
if (this.backendName === backendName) {
this.pendingBackendInit = null;
this.backendName = null;
this.backendInstance = null;
}
}
getSortedBackends() {
if (Object.keys(this.registryFactory).length === 0) {
throw new Error('No backend found in registry.');
}
return Object.keys(this.registryFactory).sort((a, b) => {
// Highest priority comes first.
return this.registryFactory[b].priority -
this.registryFactory[a].priority;
});
}
initializeBackendsAndReturnBest() {
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const { success, asyncInit } = this.initializeBackend(backendName);
if (asyncInit || success) {
return { name: backendName, asyncInit };
}
}
throw new Error(`Could not initialize any backends, all backend initializations ` +
`failed.`);
}
moveData(backend, dataId) {
const info = this.state.tensorInfo.get(dataId);
const srcBackend = info.backend;
const values = this.readSync(dataId);
// Delete the tensor from the old backend and move it to the new
// backend.
srcBackend.disposeData(dataId);
info.backend = backend;
backend.move(dataId, values, info.shape, info.dtype);
if (this.shouldCheckForMemLeaks()) {
// Track the number of moves during a kernel execution to correctly
// detect memory leaks.
this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
}
}
tidy(nameOrFn, fn) {
let name = null;
if (fn == null) {
// Called with only 1 argument.
if (typeof nameOrFn !== 'function') {
throw new Error('Please provide a function to tidy()');
}
fn = nameOrFn;
}
else {
// Called with 2 arguments.
if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
throw new Error('When calling with two arguments, the first argument ' +
'to tidy() must be a string');
}
if (typeof fn !== 'function') {
throw new Error('When calling with two arguments, the 2nd argument ' +
'to tidy() must be a function');
}
name = nameOrFn;
// TODO(nsthorat,smilkov): Do operation logging and performance
// profiling.
}
let result;
return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
result = fn();
if (result instanceof Promise) {
console.error('Cannot return a Promise inside of tidy.');
}
return result;
});
}
scopedRun(start, end, f) {
start();
try {
const res = f();
end();
return res;
}
catch (ex) {
end();
throw ex;
}
}
nextTensorId() {
return Engine.nextTensorId++;
}
nextVariableId() {
return Engine.nextVariableId++;
}
/**
* This method is called instead of the public-facing tensor.clone() when
* saving a tensor for backwards pass. It makes sure to add the clone
* operation to the tape regardless of being called inside a kernel
* execution.
*
* This method will go away once all kernels are modularized since we won't
* need to turn off the tape inside runKernel().
*/
clone(x) {
const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
const inputs = { x };
const grad = (dy) => ({ x: () => dy.toFloat() });
const saved = [];
this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}
/**
* Execute a kernel with the given name and return the output tensor.
*
* @param kernelName The name of the kernel to execute.
* @param inputs A map of input names to tensors.
* @param attrs A map of attribute names to their values. An attribute is a
* primitive (non-tensor) input to the kernel.
* @param inputsToSave A list of tensors, inputs to save for the backprop
* computation.
* @param outputsToSave A list of booleans, specifying which output to save
* for the backprop computation. These are booleans since the output
* tensors are not visible to the user.
*/
runKernel(kernelName, inputs, attrs, inputsToSave, outputsToSave) {
const forwardFunc = null;
const backwardsFunc = null;
// Call runKernel as a stop-gap until we modularize all kernels.
// Once we modularize all kernels, we will remove the existing
// `runKernelFunc`.
return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
}
shouldCheckForMemLeaks() {
return this.ENV.getBool('IS_TEST');
}
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
const numDataIdsAfter = this.backend.numDataIds();
// Count the number of data ids associated with the result of the kernel.
let numOutputDataIds = 0;
outInfos.forEach(info => {
// Complex numbers allocate 3 data ids, one for 'real', one for
// 'imaginary', and one for the container that holds the former two.
numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
});
// Account for the number of moves during kernel execution. A "data move"
// can happen in the middle of a kernel execution, placing a new (key,value)
// pair in the data storage. Since data moves have net zero effect (we
// always remove the data from the old backend), we have to cancel them out
// when detecting memory leaks.
const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
if (dataIdsLeaked > 0) {
throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
`(${dataIdsLeaked} data ids) after running '${kernelName}'`);
}
}
/**
* @deprecated Use `runKernel` for newly added kernels. Keep using this method
* only for kernels that are not yet fully modularized.
*
* Runs a kernel and returns its output tensor(s). If a kernel named
* `kernelName` is registered for the active backend (via `getKernel`), its
* `kernelFunc` is called with `{inputs, attrs, backend}`; otherwise
* `forwardFunc(backend, saveFunc)` is run inside `tidy()`. The tape is
* suspended (kernelDepth is incremented) while the kernel runs, and a tape
* node is added afterwards when the tape was on at entry.
*
* @param forwardFunc Fallback implementation used when no registered kernel
*     is found; receives the backend and a `saveFunc` for gradient tensors.
* @param inputs Map of input names to tensors.
* @param backwardsFunc Gradient function stored on the tape node.
* @param kernelName Kernel name; when null, the active scope's name (or '')
*     is used instead.
* @param attrs Non-tensor attributes passed to a registered kernel and
*     recorded on the tape node.
* @param inputsToSave Fallback list of input tensors to keep for backprop,
*     used only when `getTensorsForGradient` returns null.
* @param outputsToSave Fallback boolean mask over outputs, same condition.
* @returns An array of tensors if the kernel produced an array, otherwise the
*     single output tensor.
*/
runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
let outputs;
// Tensors kept for the backward pass; filled in by whichever path runs below.
let saved = [];
const isTapeOn = this.isTapeOn();
if (kernelName == null) {
kernelName =
this.state.activeScope != null ? this.state.activeScope.name : '';
}
// Snapshots used for the profiling record at the end of this method.
const startingBytecount = this.state.numBytes;
const startingNumTensors = this.state.numTensors;
if (this.shouldCheckForMemLeaks()) {
// New data-move counter for this call: moveData() increments the top entry
// and checkKernelForMemLeak() reads it.
// NOTE(review): nothing in this method pops the entry, so the stack grows
// by one per call in IS_TEST mode, and if kernels nest the top entry may
// belong to an inner call — confirm whether a matching pop is intended.
this.state.numDataMovesStack.push(0);
}
let kernelFunc;
const kernel = getKernel(kernelName, this.backendName);
// Raw result of the kernel/forwardFunc; only whether it is an array is used
// at the end, to decide between returning an array or a single tensor.
let out;
if (kernel != null) {
// Registered (modular) kernel path: the kernel returns objects carrying
// `dataId`, `shape` and `dtype`, which are wrapped into Tensors here.
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
const outInfos = Array.isArray(out) ? out : [out];
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
const outTensors = outInfos.map(({ dataId, shape, dtype }) => this.makeTensorFromDataId(dataId, shape, dtype));
// Save the inputs and outputs.
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (isTapeOn) {
let tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
if (tensorsToSave == null) {
// Fallback for ops that call runKernelFunc and pass in
// inputsToSave and outputsToSave. Currently this is the set of ops
// with kernel support in the WASM backend. Once those ops and
// respective gradients are modularised we can remove this path.
if (outputsToSave == null) {
outputsToSave = [];
}
const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
}
saved = this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
}
else {
// Legacy path: forwardFunc runs inside tidy() and may call saveFunc to
// choose which tensors the gradient needs.
const saveFunc = (tensors) => {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = this.tidy(() => forwardFunc(this.backend, saveFunc));
const outs = (Array.isArray(out) ? out : [out]);
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
}
return outs;
};
}
// Stop recording to a tape when running a kernel.
this.scopedRun(() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
if (!this.ENV.getBool('DEBUG')) {
outputs = kernelFunc();
}
else {
outputs = this.profiler.profileKernel(kernelName, inputs, () => kernelFunc());
}
});
// `isTapeOn` was sampled on entry, before kernelDepth was raised.
if (isTapeOn) {
this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
}
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
outputShapes: outputs.map(item => item.shape)
});
}
return (Array.isArray(out) ? outputs : outputs[0]);
}
/**
* Saves tensors used in forward mode for use in backward mode.
*
* @param tensors the list of tensors to save.
*/
saveTensorsForBackwardMode(tensors) {
const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
return saved;
}
/**
* Returns a list of tensors to save for a given gradient calculation.
*
* Returns undefined if their is no registered gradient for this kernel in the
* gradient registry.
*
* @param kernelName name of kernel to look up gradient for.
* @param inputs a map of input tensors.
* @param outputs an array of output tensors from forward mode of kernel.
*/
getTensorsForGradient(kernelName, inputs, outputs) {
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
const inputsToSave = gradConfig.inputsToSave || [];
const outputsToSave = gradConfig.outputsToSave || [];
// If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
// specified in inputsToSave will be saved.
let inputTensorsToSave;
if (gradConfig.saveAllInputs) {
assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
}
else {
inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
}
const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
return inputTensorsToSave.concat(outputTensorsToSave);
}
// TODO(yassogba) throw exception here once all runkernelFunc calls with
// inputsToSave/outputsToSave are removed
return null;
}
/**
* Internal method used by public APIs for tensor creation. Makes a new
* tensor with the provided shape, dtype and values. It always
* creates a new data id and writes the values to the underlying backend.
*/
makeTensor(values, shape, dtype, backend) {
if (values == null) {
throw new Error('Values passed to engine.makeTensor() are null');
}
dtype = dtype || 'float32';
backend = backend || this.backend;
let backendVals = values;
if (dtype === 'string' && isString(values[0])) {
backendVals = values.map(d => encodeString(d));
}
const dataId = backend.write(backendVals, shape, dtype);
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.incRef(t, backend);
// Count bytes for string tensors.
if (dtype === 'string') {
const info = this.state.tensorInfo.get(dataId);
const newBytes = bytesFromStringArray(backendVals);
this.state.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}
return t;
}
/**
* Internal method used by backends. Makes a new tensor
* that is a wrapper around an existing data id. It doesn't create
* a new data id, only increments the ref count used in memory tracking.
*/
makeTensorFromDataId(dataId, shape, dtype, backend) {
dtype = dtype || 'float32';
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.incRef(t, backend);
return t;
}
makeVariable(initialValue, trainable = true, name, dtype) {
name = name || this.nextVariableId().toString();
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.cast(dtype);
}
const v = new Variable(initialValue, trainable, name, this.nextTensorId());
if (this.state.registeredVariables[v.name] != null) {
throw new Error(`Variable with name ${v.name} was already registered`);
}
this.state.registeredVariables[v.name] = v;
this.incRef(v, this.backend);
return v;
}
incRef(a, backend) {
const refCount = this.state.tensorInfo.has(a.dataId) ?
this.state.tensorInfo.get(a.dataId).refCount :
0;
this.state.numTensors++;
if (a.dtype === 'string') {
this.state.numStringTensors++;
}
if (refCount === 0) {
this.state.numDataBuffers++;
// Bytes for complex numbers are counted by their components. Bytes for
// string tensors are counted when writing values.
let bytes = 0;
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
bytes = a.size * bytesPerElement(a.dtype);
}
this.state.tensorInfo.set(a.dataId, {
backend: backend || this.backend,
dtype: a.dtype,
shape: a.shape,
bytes,
refCount: 0
});
this.state.numBytes += bytes;
}
this.state.tensorInfo.get(a.dataId).refCount++;
if (!(a instanceof Variable)) {
this.track(a);
}
}
disposeTensor(a) {
if (!this.state.tensorInfo.has(a.dataId)) {
return;
}
this.state.numTensors--;
if (a.dtype === 'string') {
this.state.numStringTensors--;
}
const info = this.state.tensorInfo.get(a.dataId);
const refCount = info.refCount;
if (refCount <= 1) {
// Don't count bytes for complex numbers as they are counted by their
// components.
if (a.dtype !== 'complex64') {
this.state.numBytes -= info.bytes;
}
this.state.numDataBuffers--;
info.backend.disposeData(a.dataId);
this.state.tensorInfo.delete(a.dataId);
}
else {
this.state.tensorInfo.get(a.dataId).refCount--;
}
// TODO(nsthorat): Construct an error and save the stack trace for
// debugging when in debug mode. Creating a stack trace is too expensive
// to do unconditionally.
}
disposeVariables() {
for (const varName in this.state.registeredVariables) {
const v = this.state.registeredVariables[varName];
this.disposeVariable(v);
}
}
disposeVariable(v) {
this.disposeTensor(v);
if (this.state.registeredVariables[v.name] != null) {
delete this.state.registeredVariables[v.name];
}
}
memory() {
const info = this.backend.memory();
info.numTensors = this.state.numTensors;
info.numDataBuffers = this.state.numDataBuffers;
info.numBytes = this.state.numBytes;
if (this.state.numStringTensors > 0) {
info.unreliable = true;
if (info.reasons == null) {
info.reasons = [];
}
info.reasons.push('Memory usage by string tensors is approximate ' +
'(2