This file has been truncated, but you can view the full file.
/* prebuilt es */ | |
/** | |
* @license | |
* Copyright 2020 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/** | |
* @license | |
* Copyright 2017 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
// Query-string key under which flag overrides are passed (see
// Environment.populateURLFlags below).
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 */
/** @doc {heading: 'Environment'} */
class Environment {
    // tslint:disable-next-line: no-any
    constructor(global) {
        this.global = global;
        // Evaluated flag values, keyed by flag name.
        this.flags = {};
        // Per-flag evaluation function and optional set hook.
        this.flagRegistry = {};
        // Flag overrides parsed from the URL query string (browser only).
        this.urlFlags = {};
        this.populateURLFlags();
    }
    /** Installs the active platform, warning if one is already set. */
    setPlatform(platformName, platform) {
        if (this.platform != null) {
            console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${platform}.`);
        }
        this.platformName = platformName;
        this.platform = platform;
    }
    /**
     * Registers a flag's evaluation function and optional set hook.
     * Applies any URL override immediately.
     */
    registerFlag(flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn, setHook };
        // Override the flag value from the URL. This has to happen here because
        // the environment is initialized before flags get registered.
        const urlValue = this.urlFlags[flagName];
        if (urlValue != null) {
            console.warn(`Setting feature override from URL ${flagName}: ${urlValue}.`);
            this.set(flagName, urlValue);
        }
    }
    /** Evaluates (and caches) a flag, awaiting async evaluation functions. */
    async getAsync(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        this.flags[flagName] = await this.evaluateFlag(flagName);
        return this.flags[flagName];
    }
    /**
     * Evaluates (and caches) a flag synchronously. Throws if the flag's
     * evaluation function returns a Promise.
     */
    get(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        const evaluated = this.evaluateFlag(flagName);
        if (evaluated instanceof Promise) {
            throw new Error(`Flag ${flagName} cannot be synchronously evaluated. Please use getAsync() instead.`);
        }
        this.flags[flagName] = evaluated;
        return this.flags[flagName];
    }
    getNumber(flagName) {
        return this.get(flagName);
    }
    getBool(flagName) {
        return this.get(flagName);
    }
    getFlags() {
        return this.flags;
    }
    // For backwards compatibility.
    get features() {
        return this.flags;
    }
    /** Sets a registered flag's value and fires its set hook, if any. */
    set(flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
        }
        this.flags[flagName] = value;
        const hook = this.flagRegistry[flagName].setHook;
        if (hook != null) {
            hook(value);
        }
    }
    /** Runs the registered evaluation function for a flag. */
    evaluateFlag(flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
        }
        return this.flagRegistry[flagName].evaluationFn();
    }
    /** Replaces all evaluated flag values with a shallow copy of `flags`. */
    setFlags(flags) {
        this.flags = Object.assign({}, flags);
    }
    /** Clears evaluated flags and re-parses URL overrides. */
    reset() {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    }
    /** Parses ?tfjsflags=... overrides from the global's location, if any. */
    populateURLFlags() {
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        const urlParams = getQueryParams(this.global.location.search);
        if (!(TENSORFLOWJS_FLAGS_PREFIX in urlParams)) {
            return;
        }
        for (const keyValue of urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',')) {
            const [key, value] = keyValue.split(':');
            this.urlFlags[key] = parseValue(key, value);
        }
    }
}
/** Parses a URL query string (e.g. '?a=1&b=2') into a key/value map. */
function getQueryParams(queryString) {
    const params = {};
    // Each match is a `key` or `key=value` pair; value may be absent.
    const pattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
    let match = pattern.exec(queryString);
    while (match != null) {
        decodeParam(params, match[1], match[2]);
        match = pattern.exec(queryString);
    }
    return params;
}
/** URI-decodes a query parameter and stores it on `params` (in place). */
function decodeParam(params, name, value) {
    const key = decodeURIComponent(name);
    // A missing value (bare `?flag`) decodes to the empty string.
    params[key] = decodeURIComponent(value || '');
}
/**
 * Parses a URL flag value string into a boolean or number.
 *
 * @param flagName The flag name, used only in the error message.
 * @param value The raw value string ('true'/'false' or a numeric literal).
 * @throws Error if the value is neither a boolean literal nor a number.
 */
function parseValue(flagName, value) {
    value = value.toLowerCase();
    if (value === 'true' || value === 'false') {
        return value === 'true';
    }
    else if (`${+value}` === value) {
        // Round-trips through Number exactly, so it is a plain numeric literal.
        return +value;
    }
    // Fixed duplicated word in the message ("parse value flag value").
    throw new Error(`Could not parse flag value ${value} for flag ${flagName}.`);
}
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform.
 */
/** @doc {heading: 'Environment'} */
function env() {
    return ENV;
}
// Module-level singleton; assigned via `setEnvironmentGlobal()` before use.
let ENV = null;
/** Installs `environment` as the singleton that `env()` returns. */
function setEnvironmentGlobal(environment) {
    ENV = environment;
}
/** | |
* @license | |
* Copyright 2020 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
let globalNameSpace;
/**
 * Lazily resolves and caches the host's global object: `window` in browsers,
 * `global`/`process` in Node, `self` in workers.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
    if (globalNameSpace == null) {
        // typeof guards avoid ReferenceErrors for undeclared identifiers.
        if (typeof window !== 'undefined') {
            globalNameSpace = window;
        }
        else if (typeof global !== 'undefined') {
            globalNameSpace = global;
        }
        else if (typeof process !== 'undefined') {
            globalNameSpace = process;
        }
        else if (typeof self !== 'undefined') {
            globalNameSpace = self;
        }
        else {
            throw new Error('Could not find a global object');
        }
    }
    return globalNameSpace;
}
// tslint:disable-next-line:no-any
/** Returns (creating on first use) the `_tfGlobals` Map on the global object. */
function getGlobalMap() {
    const namespace = getGlobalNamespace();
    if (namespace._tfGlobals == null) {
        namespace._tfGlobals = new Map();
    }
    return namespace._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function to initialize this object the first time it is
 *     fetched.
 */
function getGlobal(key, init) {
    const globalMap = getGlobalMap();
    if (!globalMap.has(key)) {
        // First fetch: create the singleton and remember it.
        globalMap.set(key, init());
    }
    return globalMap.get(key);
}
/** | |
* @license | |
* Copyright 2019 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
// Process-wide registries shared across module copies: kernel configs keyed
// by `${backendName}_${kernelName}` (see makeKey) and gradient configs keyed
// by kernel name.
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
const gradRegistry = getGlobal('gradRegistry', () => new Map());
/**
 * Returns the kernel function (code) associated with the provided names.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 */
function getKernel(kernelName, backendName) {
    // Returns undefined when no kernel is registered under this key.
    return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Returns the registered gradient info associated with the provided kernel.
 * @param kernelName The official TF kernel name.
 */
function getGradient(kernelName) {
    // Returns undefined when no gradient is registered for this kernel.
    return gradRegistry.get(kernelName);
}
/**
 * Returns all kernel configs registered for the given backend.
 *
 * @param backendName The official name of the backend.
 */
function getKernelsForBackend(backendName) {
    const result = [];
    // Registry keys have the form `${backendName}_${kernelName}` (see makeKey);
    // everything before the first '_' is the backend name.
    // NOTE(review): a backend name that itself contains '_' would be truncated
    // by this split — pre-existing behavior, left unchanged.
    for (const [key, config] of kernelRegistry.entries()) {
        const [backend] = key.split('_');
        if (backend === backendName) {
            result.push(config);
        }
    }
    return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
    const key = makeKey(config.kernelName, config.backendName);
    if (kernelRegistry.has(key)) {
        // Warn but still overwrite, matching last-registration-wins semantics.
        console.warn(`The kernel '${config.kernelName}' for backend '${config.backendName}' is already registered`);
    }
    kernelRegistry.set(key, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
    const { kernelName } = config;
    // TODO (yassogba) after 3.0 assess whether we need to keep this gated
    // to debug mode.
    if (gradRegistry.has(kernelName) && env().getBool('DEBUG')) {
        console.warn(`Overriding the gradient for '${kernelName}'`);
    }
    gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel function from the registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 *
 */
function unregisterKernel(kernelName, backendName) {
    const key = makeKey(kernelName, backendName);
    if (!kernelRegistry.has(key)) {
        throw new Error(`The kernel '${kernelName}' for backend '${backendName}' is not registered`);
    }
    kernelRegistry.delete(key);
}
/** Removes the registered gradient from the global registry. */
function unregisterGradient(kernelName) {
    if (!gradRegistry.has(kernelName)) {
        // Deleting an unknown gradient is a programming error, so fail loudly.
        throw new Error(`The gradient '${kernelName}' for backend is not registered`);
    }
    gradRegistry.delete(kernelName);
}
/** Builds the kernel-registry key: backend name first, then kernel name. */
function makeKey(kernelName, backendName) {
    return backendName + '_' + kernelName;
}
/** | |
* @license | |
* Copyright 2017 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array to shuffle in-place.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
// tslint:disable-next-line:no-any
function shuffle(array) {
    // Walk from the back; swap each slot with a uniformly chosen slot at or
    // before it (| 0 truncates to an integer index).
    for (let i = array.length - 1; i >= 0; i--) {
        const j = (Math.random() * (i + 1)) | 0;
        const swap = array[i];
        array[i] = array[j];
        array[j] = swap;
    }
}
/** Clamps a value to a specified range. */
function clamp(min, x, max) {
    // Same expression as Math.max(min, Math.min(x, max)), split for clarity.
    const upperBounded = Math.min(x, max);
    return Math.max(min, upperBounded);
}
/** Returns `val` if even, otherwise the next larger integer (`val + 1`). */
function nearestLargerEven(val) {
    if (val % 2 === 0) {
        return val;
    }
    return val + 1;
}
/** Sums a numeric array (or TypedArray); returns 0 for an empty input. */
function sum(arr) {
    return arr.reduce((total, value) => total + value, 0);
}
/**
 * Returns a sample from a uniform [a, b) distribution.
 *
 * @param a The minimum support (inclusive).
 * @param b The maximum support (exclusive).
 * @return A pseudorandom number on the half-open interval [a,b).
 */
function randUniform(a, b) {
    const t = Math.random();
    // Linear interpolation between a and b; t < 1 keeps b exclusive.
    return (b * t) + (1 - t) * a;
}
/** Returns the squared Euclidean distance between two vectors. */
function distSquared(a, b) {
    let total = 0;
    for (let i = 0; i < a.length; i++) {
        // Number() coerces TypedArray/array elements alike.
        const delta = Number(a[i]) - Number(b[i]);
        total += delta * delta;
    }
    return total;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg A function that returns the message to report when throwing an
 *     error. We use a function for performance reasons.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function assert(expr, msg) {
    if (expr) {
        return;
    }
    // msg may be a plain string or a lazy message factory.
    throw new Error(typeof msg === 'string' ? msg : msg());
}
/** Asserts element-wise equality of two shapes, throwing via `assert`. */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
    assert(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
}
/** Asserts the value is neither null nor undefined. */
function assertNonNull(a) {
    assert(a != null, () => `The input to the tensor constructor must be a non-null value.`);
}
// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
 * Flattens an arbitrarily nested array.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * @param arr The nested array to flatten.
 * @param result The destination array which holds the elements.
 * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
 *     to false.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function flatten(arr, result = [], skipTypedArray = false) {
    if (result == null) {
        result = [];
    }
    // Recurse into plain arrays always; into typed arrays only when allowed.
    const shouldRecurse =
        Array.isArray(arr) || (isTypedArray(arr) && !skipTypedArray);
    if (shouldRecurse) {
        for (let i = 0; i < arr.length; ++i) {
            flatten(arr[i], result, skipTypedArray);
        }
    }
    else {
        result.push(arr);
    }
    return result;
}
/**
 * Returns the size (number of elements) of the tensor given its shape.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function sizeFromShape(shape) {
    // Product of all dims; the seed 1 makes the empty (scalar) shape size 1.
    return shape.reduce((size, dim) => size * dim, 1);
}
/** Returns true if the shape has rank 0 (a scalar). */
function isScalarShape(shape) {
    return shape.length === 0;
}
/**
 * Returns true if the two arrays (or TypedArrays) are element-wise strictly
 * equal. Same reference compares equal; null/undefined only equals itself.
 */
function arraysEqual(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null || n1.length !== n2.length) {
        return false;
    }
    for (let i = 0; i < n1.length; i++) {
        if (n1[i] !== n2[i]) {
            return false;
        }
    }
    return true;
}
/** Returns true if `a` is an integer-valued number (no fractional part). */
function isInt(a) {
    return a % 1 === 0;
}
/** Hyperbolic tangent; falls back to an exp-based formula if Math.tanh is absent. */
function tanh(x) {
    // tslint:disable-next-line:no-any
    if (Math.tanh != null) {
        // tslint:disable-next-line:no-any
        return Math.tanh(x);
    }
    if (x === Infinity) {
        return 1;
    }
    if (x === -Infinity) {
        return -1;
    }
    // tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)
    const e2x = Math.exp(2 * x);
    return (e2x - 1) / (e2x + 1);
}
/** Returns a near-square [width, height] whose product is at least `size`. */
function sizeToSquarishShape(size) {
    const width = Math.ceil(Math.sqrt(size));
    const height = Math.ceil(size / width);
    return [width, height];
}
/**
 * Creates a new array with randomized indicies to a given quantity.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param number Quantity of how many shuffled indicies to create.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function createShuffledIndices(n) {
    // Fill 0..n-1, then shuffle in place.
    const indices = new Uint32Array(n);
    for (let i = 0; i < n; ++i) {
        indices[i] = i;
    }
    shuffle(indices);
    return indices;
}
/** Right-pads `a` with spaces to length `size`; no-op if already long enough. */
function rightPad(a, size) {
    if (size <= a.length) {
        return a;
    }
    return a.padEnd(size, ' ');
}
/**
 * Polls `checkFn` until it returns true, waiting `delayFn(attempt)` ms between
 * attempts. Rejects (with no reason, matching prior behavior) once
 * `maxCounter` attempts have failed, if `maxCounter` is provided.
 */
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
    return new Promise((resolve, reject) => {
        let attempts = 0;
        const poll = () => {
            if (checkFn()) {
                resolve();
                return;
            }
            attempts++;
            // delayFn is invoked before the max check, as in the original.
            const backoffMs = delayFn(attempts);
            if (maxCounter != null && attempts >= maxCounter) {
                reject();
                return;
            }
            setTimeout(poll, backoffMs);
        };
        poll();
    });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size.
 */
function inferFromImplicitShape(shape, size) {
    let product = 1;
    let inferredDim = -1;
    for (let i = 0; i < shape.length; ++i) {
        if (shape[i] >= 0) {
            product *= shape[i];
        }
        else if (shape[i] === -1) {
            if (inferredDim !== -1) {
                throw Error(`Shapes can only have 1 implicit size. ` +
                    `Found -1 at dim ${inferredDim} and dim ${i}`);
            }
            inferredDim = i;
        }
        else if (shape[i] < 0) {
            throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);
        }
    }
    if (inferredDim === -1) {
        // Nothing to infer; just validate the size (size 0 skips the check).
        if (size > 0 && size !== product) {
            throw Error(`Size(${size}) must match the product of shape ${shape}`);
        }
        return shape;
    }
    if (product === 0) {
        throw Error(`Cannot infer the missing size in [${shape}] when ` +
            `there are 0 elements`);
    }
    if (size % product !== 0) {
        throw Error(`The implicit shape can't be a fractional number. ` +
            `Got ${size} / ${product}`);
    }
    const result = shape.slice();
    result[inferredDim] = size / product;
    return result;
}
/**
 * Normalizes an axis argument: null means "all axes", scalars are wrapped,
 * and negative axes are mapped to their positive equivalents.
 */
function parseAxisParam(axis, shape) {
    const rank = shape.length;
    // Default to every axis; [].concat wraps a scalar into an array.
    const axes = axis == null ? shape.map((s, i) => i) : [].concat(axis);
    // Check for valid range
    assert(axes.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
        `got axis ${axes}`);
    // Check for only integers
    assert(axes.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` +
        `got axis ${axes}`);
    // Handle negative axis.
    return axes.map(a => a < 0 ? rank + a : a);
}
/**
 * Reduces the shape by removing all dimensions of shape 1.
 *
 * @param shape The input shape.
 * @param axis Optional list of axes to squeeze (each must have dim 1).
 * @return newShape (squeezed dims removed) and keptDims (original indices of
 *     the kept dims).
 */
function squeezeShape(shape, axis) {
    const newShape = [];
    const keptDims = [];
    // An explicitly empty axis array is treated like no axis argument.
    const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
    // BUG FIX: .sort() without a comparator sorts numerically-valued axes
    // lexicographically (e.g. [2, 10] -> [10, 2]), which breaks the ascending
    // walk below for ranks > 10. Sort numerically instead.
    const axes = (axis == null || isEmptyArray) ?
        null :
        parseAxisParam(axis, shape).sort((a, b) => a - b);
    // `j` walks the sorted axes list in step with `i` over the shape.
    let j = 0;
    for (let i = 0; i < shape.length; ++i) {
        if (axes != null) {
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
            }
            // Keep size-1 dims that were not explicitly selected for squeezing.
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            if (axes[j] <= i) {
                j++;
            }
        }
        // Dims with size != 1 are always kept.
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    return { newShape, keptDims };
}
/**
 * Allocates a zero-filled TypedArray of `size` for the given dtype
 * (float32 -> Float32Array, int32 -> Int32Array, bool -> Uint8Array).
 * A null/undefined dtype defaults to float32.
 */
function getTypedArrayFromDType(dtype, size) {
    if (dtype == null || dtype === 'float32') {
        return new Float32Array(size);
    }
    if (dtype === 'int32') {
        return new Int32Array(size);
    }
    if (dtype === 'bool') {
        return new Uint8Array(size);
    }
    throw new Error(`Unknown data type ${dtype}`);
}
/**
 * Like getTypedArrayFromDType but additionally supports dtype 'string',
 * for which a plain Array of the given length is returned.
 */
function getArrayFromDType(dtype, size) {
    if (dtype == null || dtype === 'float32') {
        return new Float32Array(size);
    }
    if (dtype === 'int32') {
        return new Int32Array(size);
    }
    if (dtype === 'bool') {
        return new Uint8Array(size);
    }
    if (dtype === 'string') {
        return new Array(size);
    }
    throw new Error(`Unknown data type ${dtype}`);
}
/** Throws if any value is NaN or non-finite (used for DEBUG-mode uploads). */
function checkConversionForErrors(vals, dtype) {
    for (const num of vals) {
        if (isNaN(num) || !isFinite(num)) {
            throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
        }
    }
}
/** Returns true if the dtype is valid. */
function isValidDtype(dtype) {
    return ['bool', 'complex64', 'float32', 'int32', 'string'].includes(dtype);
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 */
function hasEncodingLoss(oldType, newType) {
    switch (newType) {
        case 'complex64':
            // complex64 can encode everything.
            return false;
        case 'float32':
            return oldType === 'complex64';
        case 'int32':
            return oldType === 'float32' || oldType === 'complex64';
        case 'bool':
            return oldType !== 'bool';
        default:
            // Any other target type is treated as lossy.
            return true;
    }
}
/**
 * Returns true only for the three TypedArray flavors tfjs stores
 * (Float32Array/Int32Array/Uint8Array); other TypedArrays such as
 * Float64Array return false, matching prior behavior.
 */
function isTypedArray(a) {
    return [Float32Array, Int32Array, Uint8Array]
        .some((ctor) => a instanceof ctor);
}
/** Returns the number of bytes one element of the given dtype occupies. */
function bytesPerElement(dtype) {
    switch (dtype) {
        case 'float32':
        case 'int32':
            return 4;
        case 'complex64':
            return 8;
        case 'bool':
            return 1;
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
/**
 * Returns the approximate number of bytes allocated in the string array - 2
 * bytes per character. Computing the exact bytes for a native string in JS is
 * not possible since it depends on the encoding of the html page that serves
 * the website.
 */
function bytesFromStringArray(arr) {
    if (arr == null) {
        return 0;
    }
    // Sums the .length of each element (encoded string buffers).
    return arr.reduce((bytes, x) => bytes + x.length, 0);
}
/** Returns true if the value is a string (primitive or boxed String). */
function isString(value) {
    return typeof value === 'string' || value instanceof String;
}
/** Returns true if the value is a boolean primitive. */
function isBoolean(value) {
    return typeof value === 'boolean';
}
/** Returns true if the value is a number primitive. */
function isNumber(value) {
    return typeof value === 'number';
}
/**
 * Infers a tensor dtype from a JS value: recurses into the first element of
 * arrays; maps Float32Array -> 'float32', Int32Array/Uint8Array -> 'int32',
 * number -> 'float32', string -> 'string', boolean -> 'bool'; anything else
 * defaults to 'float32'.
 */
function inferDtype(values) {
    if (Array.isArray(values)) {
        // Only the first element is inspected (pre-existing behavior).
        return inferDtype(values[0]);
    }
    if (values instanceof Float32Array) {
        return 'float32';
    }
    if (values instanceof Int32Array || values instanceof Uint8Array) {
        return 'int32';
    }
    if (isNumber(values)) {
        return 'float32';
    }
    if (isString(values)) {
        return 'string';
    }
    if (isBoolean(values)) {
        return 'bool';
    }
    return 'float32';
}
/** Duck-type check for callables: requires constructor, call and apply. */
function isFunction(f) {
    return Boolean(f && f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` that is >= `start` (and < size);
 * falls back to `size` itself when none exists in that range.
 */
function nearestDivisor(size, start) {
    for (let candidate = start; candidate < size; ++candidate) {
        if (size % candidate === 0) {
            return candidate;
        }
    }
    return size;
}
/**
 * Computes row-major strides for a shape. The last dimension has an implicit
 * stride of 1, so rank D produces D-1 strides; ranks 0 and 1 produce [].
 */
function computeStrides(shape) {
    const rank = shape.length;
    if (rank < 2) {
        return [];
    }
    const strides = new Array(rank - 1);
    // Accumulate the running product of trailing dims, right to left.
    let running = 1;
    for (let i = rank - 2; i >= 0; --i) {
        running *= shape[i + 1];
        strides[i] = running;
    }
    return strides;
}
/**
 * Converts a (possibly nested) numeric array to the TypedArray matching
 * `dtype`. Flattens nested arrays first; in DEBUG mode validates that all
 * values are finite; returns the input unchanged when it is already the
 * right TypedArray kind.
 */
function toTypedArray(a, dtype) {
    if (dtype === 'string') {
        throw new Error('Cannot convert a string[] to a TypedArray');
    }
    if (Array.isArray(a)) {
        a = flatten(a);
    }
    if (env().getBool('DEBUG')) {
        checkConversionForErrors(a, dtype);
    }
    if (noConversionNeeded(a, dtype)) {
        return a;
    }
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(a);
    }
    if (dtype === 'int32') {
        return new Int32Array(a);
    }
    if (dtype === 'bool') {
        // Any value that rounds to non-zero becomes 1.
        const bool = new Uint8Array(a.length);
        for (let i = 0; i < bool.length; ++i) {
            if (Math.round(a[i]) !== 0) {
                bool[i] = 1;
            }
        }
        return bool;
    }
    throw new Error(`Unknown data type ${dtype}`);
}
/**
 * Recursively builds a nested JS array of the given shape from the flat
 * buffer `a`, reading `sizeOf(shape)` elements starting at `offset`.
 */
function createNestedArray(offset, shape, a) {
    const ret = [];
    const d = shape[0];
    if (shape.length === 1) {
        // Base case: copy a contiguous run of d elements.
        for (let i = 0; i < d; i++) {
            ret[i] = a[offset + i];
        }
    }
    else {
        const rest = shape.slice(1);
        // Elements per sub-array = product of the remaining dims.
        const stride = rest.reduce((acc, c) => acc * c);
        for (let i = 0; i < d; i++) {
            ret[i] = createNestedArray(offset + i * stride, rest, a);
        }
    }
    return ret;
}
// Provide a nested array of TypedArray in given shape.
function toNestedArray(shape, a) {
    if (shape.length === 0) {
        // Scalar type should return a single number.
        return a[0];
    }
    const size = shape.reduce((acc, c) => acc * c);
    if (size === 0) {
        // A tensor with shape zero should be turned into empty list.
        return [];
    }
    if (size !== a.length) {
        throw new Error(`[${shape}] does not match the input size ${a.length}.`);
    }
    return createNestedArray(0, shape, a);
}
/** True when `a` is already the TypedArray flavor that `dtype` requires. */
function noConversionNeeded(a, dtype) {
    if (a instanceof Float32Array) {
        return dtype === 'float32';
    }
    if (a instanceof Int32Array) {
        return dtype === 'int32';
    }
    if (a instanceof Uint8Array) {
        return dtype === 'bool';
    }
    return false;
}
/** Allocates a TypedArray for `dtype` with every element set to 1. */
function makeOnesTypedArray(size, dtype) {
    // fill() mutates and returns the same array.
    return makeZerosTypedArray(size, dtype).fill(1);
}
/**
 * Allocates a zero-filled TypedArray of `size` for the given dtype.
 * float32/complex64 (or null) -> Float32Array, int32 -> Int32Array,
 * bool -> Uint8Array.
 */
function makeZerosTypedArray(size, dtype) {
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(size);
    }
    if (dtype === 'int32') {
        return new Int32Array(size);
    }
    if (dtype === 'bool') {
        return new Uint8Array(size);
    }
    throw new Error(`Unknown data type ${dtype}`);
}
/**
 * Make nested `TypedArray` filled with zeros.
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element.
 */
function makeZerosNestedTypedArray(shape, dtype) {
    const size = shape.reduce((prev, curr) => prev * curr, 1);
    if (dtype == null || dtype === 'float32') {
        return toNestedArray(shape, new Float32Array(size));
    }
    if (dtype === 'int32') {
        return toNestedArray(shape, new Int32Array(size));
    }
    if (dtype === 'bool') {
        return toNestedArray(shape, new Uint8Array(size));
    }
    throw new Error(`Unknown data type ${dtype}`);
}
/**
 * Returns the current high-resolution time in milliseconds relative to an
 * arbitrary time in the past. It works across different platforms (node.js,
 * browsers).
 *
 * ```js
 * console.log(tf.util.now());
 * ```
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function now() {
    // Delegates to the registered platform; requires setPlatform() to have run.
    return env().platform.now();
}
/** Asserts every dimension in `shape` is a non-negative integer. */
function assertNonNegativeIntegerDimensions(shape) {
    for (const dimSize of shape) {
        assert(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` +
            `shape [${shape}].`);
    }
}
/**
 * Returns a platform-specific implementation of
 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
 *
 * If `fetch` is defined on the global object (`window`, `process`, etc.),
 * `tf.util.fetch` returns that function.
 *
 * If not, `tf.util.fetch` returns a platform-specific solution.
 *
 * ```js
 * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
 * // handle response
 * ```
 */
/** @doc {heading: 'Util'} */
function fetch$1(path, requestInits) {
    // Delegates to the registered platform's fetch implementation.
    return env().platform.fetch(path, requestInits);
}
/**
 * Encodes the provided string into bytes using the provided encoding scheme.
 *
 * @param s The string to encode.
 * @param encoding The encoding scheme. Defaults to utf-8.
 *
 */
/** @doc {heading: 'Util'} */
function encodeString(s, encoding = 'utf-8') {
    // The || guard also covers an explicit null/empty encoding argument,
    // which the default parameter alone would not.
    encoding = encoding || 'utf-8';
    return env().platform.encode(s, encoding);
}
/**
 * Decodes the provided bytes into a string using the provided encoding scheme.
 * @param bytes The bytes to decode.
 *
 * @param encoding The encoding scheme. Defaults to utf-8.
 */
/** @doc {heading: 'Util'} */
function decodeString(bytes, encoding = 'utf-8') {
    // Same null/empty guard as encodeString.
    encoding = encoding || 'utf-8';
    return env().platform.decode(bytes, encoding);
}
/**
 * Computes the flat index for a given location (multidimensional index) in a
 * Tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides.
 * @returns The offset into the flat backing array.
 */
function locToIndex(locs, rank, strides) {
    if (rank === 0) {
        return 0;
    }
    if (rank === 1) {
        return locs[0];
    }
    // Dot product of coordinates with strides; the last axis has an implicit
    // stride of 1.
    return locs.reduce((acc, loc, axis) => axis === locs.length - 1 ? acc + loc : acc + loc * strides[axis], 0);
}
/**
 * Computes the location (multidimensional index) in a tensor/multidimensional
 * array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor.
 * @returns The coordinate array (length `rank`).
 */
function indexToLoc(index, rank, strides) {
    if (rank === 0) {
        return [];
    }
    if (rank === 1) {
        return [index];
    }
    // Repeated integer division by each stride; the remainder after the last
    // division is the final coordinate.
    const locs = [];
    let remainder = index;
    for (let axis = 0; axis < rank - 1; ++axis) {
        const coord = Math.floor(remainder / strides[axis]);
        locs.push(coord);
        remainder -= coord * strides[axis];
    }
    locs.push(remainder);
    return locs;
}
// Namespace object exposed as `tf.util`, aggregating the helpers defined
// above. `__proto__: null` keeps it prototype-free; `fetch$1` is re-exported
// under its public name `fetch`.
var util = {
    __proto__: null,
    shuffle,
    clamp,
    nearestLargerEven,
    sum,
    randUniform,
    distSquared,
    assert,
    assertShapesMatch,
    assertNonNull,
    flatten,
    sizeFromShape,
    isScalarShape,
    arraysEqual,
    isInt,
    tanh,
    sizeToSquarishShape,
    createShuffledIndices,
    rightPad,
    repeatedTry,
    inferFromImplicitShape,
    parseAxisParam,
    squeezeShape,
    getTypedArrayFromDType,
    getArrayFromDType,
    checkConversionForErrors,
    isValidDtype,
    hasEncodingLoss,
    isTypedArray,
    bytesPerElement,
    bytesFromStringArray,
    isString,
    isBoolean,
    isNumber,
    inferDtype,
    isFunction,
    nearestDivisor,
    computeStrides,
    toTypedArray,
    toNestedArray,
    makeOnesTypedArray,
    makeZerosTypedArray,
    makeZerosNestedTypedArray,
    now,
    assertNonNegativeIntegerDimensions,
    fetch: fetch$1,
    encodeString,
    decodeString,
    locToIndex,
    indexToLoc
};
/** | |
* @license | |
* Copyright 2018 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * Times kernel execution via a backend-provided timer and checks every output
 * tensor for non-finite values, logging one profile line per output.
 */
class Profiler {
    /**
     * @param backendTimer Object exposing `time(fn)`, which runs `fn` and
     *     returns a promise of timing info (`kernelMs`, optional
     *     `getExtraProfileInfo`).
     * @param logger Optional logger; a default console `Logger` is created
     *     when null/undefined.
     */
    constructor(backendTimer, logger) {
        this.backendTimer = backendTimer;
        this.logger = logger;
        if (logger == null) {
            this.logger = new Logger();
        }
    }
    /**
     * Runs `f` under the backend timer, returning its outputs synchronously.
     * Value checks and logging happen asynchronously as each output's data
     * resolves.
     */
    profileKernel(kernelName, inputs, f) {
        let outputs;
        const holdResultWrapperFn = () => {
            outputs = f();
        };
        const timer = this.backendTimer.time(holdResultWrapperFn);
        // NOTE(review): `outputs` is read immediately after `time()` returns,
        // so this assumes the timer invokes the wrapper synchronously even
        // though the timing itself resolves later — confirm against backends.
        outputs.forEach(r => {
            // Dangling promise here because we don't want to propagate up
            // asynchronicity.
            r.data().then(vals => {
                checkComputationForErrors(vals, r.dtype, kernelName);
                timer.then(timing => {
                    let extraInfo = '';
                    if (timing.getExtraProfileInfo != null) {
                        extraInfo = timing.getExtraProfileInfo();
                    }
                    this.logger.logKernelProfile(kernelName, r, vals, timing.kernelMs, inputs, extraInfo);
                });
            });
        });
        return outputs;
    }
}
/**
 * Scans float32 results for NaN/Infinity and warns on the first one found.
 *
 * @param vals The computed values (array or typed array of numbers).
 * @param dtype The result dtype; anything other than 'float32' is skipped.
 * @param kernelName Kernel name used in the warning message.
 * @returns true when a non-finite value was found, false otherwise.
 */
function checkComputationForErrors(vals, dtype, kernelName) {
    if (dtype !== 'float32') {
        // Only floating point computations will generate NaN values
        return false;
    }
    for (let i = 0; i < vals.length; i++) {
        const num = vals[i];
        // Number.isFinite is false for NaN, Infinity and -Infinity, replacing
        // the coercing globals `isNaN(num) || !isFinite(num)` (equivalent here
        // since `num` is always a number). Note this warns rather than throws.
        if (!Number.isFinite(num)) {
            console.warn(`Found ${num} in the result of '${kernelName}'`);
            return true;
        }
    }
    return false;
}
/**
 * Default console logger for kernel profiles. Prints a single styled line per
 * kernel: padded name, elapsed time, output rank/shape, size, a description
 * of each input's shape, and any extra backend-provided info.
 */
class Logger {
    logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
        const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
            timeMs['error'];
        const paddedName = rightPad(name, 25);
        const rank = result.rank;
        const size = result.size;
        const shape = rightPad(result.shape.toString(), 14);
        const pieces = [];
        for (const inputName in inputs) {
            const input = inputs[inputName];
            // The input might be a non-tensor (e.g HTMLImageElement), in which
            // case we claim the output shape as input shape.
            const inputShape = input.shape || result.shape;
            const inputRank = inputShape.length;
            pieces.push(`${inputName}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `);
        }
        const inputShapesDescription = pieces.join('');
        console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
    }
}
/** | |
* @license | |
* Copyright 2017 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns Copies of the tape nodes on a path from some x to y, in tape
 *     order, with each node's inputs pruned to those reachable from x.
 */
function getFilteredNodesXToY(tape, xs, y) {
    // Forward pass to compute all the nodes and Tensors that are transitively a
    // function of x.
    const tensorsFromX = {};
    const nodesFromX = {};
    for (let i = 0; i < xs.length; i++) {
        tensorsFromX[xs[i].id] = true;
    }
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        for (const inputName in nodeInputs) {
            const input = nodeInputs[inputName];
            // A single membership check suffices; the previous code wrapped
            // this in a loop over `xs` whose index was never used, re-testing
            // the same invariant condition xs.length times.
            if (tensorsFromX[input.id]) {
                node.outputs.forEach(output => tensorsFromX[output.id] = true);
                nodesFromX[node.id] = true;
                break;
            }
        }
    }
    // Backward pass to find all of the nodes and Tensors that lead to y.
    const tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    const nodesToY = {};
    for (let i = tape.length - 1; i >= 0; i--) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        // If any of the outputs lead to y, mark all of the inputs as leading to y.
        for (let j = 0; j < node.outputs.length; j++) {
            if (tensorsLeadToY[node.outputs[j].id]) {
                for (const inputName in nodeInputs) {
                    tensorsLeadToY[nodeInputs[inputName].id] = true;
                    nodesToY[node.id] = true;
                }
                break;
            }
        }
    }
    // Return the paths that come from x and lead to y.
    const filteredTape = [];
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        if (nodesFromX[node.id] && nodesToY[node.id]) {
            // Prune the inputs from the node that aren't a function of x.
            const prunedInputs = {};
            for (const inputName in node.inputs) {
                const nodeInput = node.inputs[inputName];
                if (tensorsFromX[nodeInput.id]) {
                    prunedInputs[inputName] = nodeInput;
                }
            }
            // Copy the node and overwrite inputsAndArgs to the pruned version.
            const prunedNode = Object.assign({}, node);
            prunedNode.inputs = prunedInputs;
            prunedNode.outputs = node.outputs;
            filteredTape.push(prunedNode);
        }
    }
    return filteredTape;
}
/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
 * is mutated by this method.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Wrapper used to run each per-input gradient function
 * (presumably `tf.tidy` — memory semantics depend on the caller; confirm).
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) {
    // Walk the tape backward and keep a map of Tensor to its gradient.
    for (let i = filteredTape.length - 1; i >= 0; i--) {
        const node = filteredTape[i];
        const dys = [];
        node.outputs.forEach(o => {
            const gradTensor = tensorAccumulatedGradientMap[o.id];
            if (gradTensor != null) {
                dys.push(gradTensor);
            }
            else {
                // This particular output is not in the back-propagation subgraph, so it
                // does not affect the final output, thus we put null for its dy.
                dys.push(null);
            }
        });
        if (node.gradient == null) {
            throw new Error(`Cannot compute gradient: gradient function not found ` +
                `for ${node.kernelName}.`);
        }
        // Backprop dy through this node and accumulate gradients over the inputs.
        const inputGradients = node.gradient(dys);
        for (const inputName in node.inputs) {
            if (!(inputName in inputGradients)) {
                throw new Error(`Cannot backprop through input ${inputName}. ` +
                    `Available gradients found: ${Object.keys(inputGradients)}.`);
            }
            // Call the gradient function.
            const dx = tidy(() => inputGradients[inputName]());
            if (dx.dtype !== 'float32') {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
            }
            // Gradient shape must match the input it is a gradient of.
            const x = node.inputs[inputName];
            if (!arraysEqual(dx.shape, x.shape)) {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `'${inputName}' has shape '${dx.shape}', which does not match ` +
                    `the shape of the input '${x.shape}'`);
            }
            if (tensorAccumulatedGradientMap[x.id] == null) {
                // First gradient seen for this input: store it directly.
                tensorAccumulatedGradientMap[x.id] = dx;
            }
            else {
                // Accumulate into the existing gradient; dispose the old tensor
                // only AFTER the sum has been created from it.
                const curGradient = tensorAccumulatedGradientMap[x.id];
                tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
                curGradient.dispose();
            }
        }
    }
}
/** | |
* @license | |
* Copyright 2018 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
// Knobs controlling tensorToString/subTensorToString output below.
// Maximum number of values before we decide to show ellipsis.
const FORMAT_LIMIT_NUM_VALS = 20;
// Number of first and last values to show when displaying a, b,...,y, z.
const FORMAT_NUM_FIRST_LAST_VALS = 3;
// Number of significant digits to show.
const FORMAT_NUM_SIG_DIGITS = 7;
/**
 * Renders a tensor's values as a human-readable multi-line string.
 *
 * @param vals Flat values of the tensor.
 * @param shape Tensor shape.
 * @param dtype Tensor dtype (controls per-value formatting).
 * @param verbose When true, prepends dtype/rank/shape header lines.
 */
function tensorToString(vals, shape, dtype, verbose) {
    const strides = computeStrides(shape);
    const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
    const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
    const lines = ['Tensor'];
    if (verbose) {
        lines.push(` dtype: ${dtype}`);
        lines.push(` rank: ${shape.length}`);
        lines.push(` shape: [${shape}]`);
        lines.push(` values:`);
    }
    const indentedBody = valsLines.map(l => ' ' + l).join('\n');
    lines.push(indentedBody);
    return lines.join('\n');
}
/**
 * Computes, for each column of the innermost dimension, the widest rendered
 * value, so columns can be right-padded to align. Returns all zeros for
 * rank <= 1 tensors (no column alignment needed).
 */
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
    const n = sizeFromShape(shape);
    const numCols = strides[strides.length - 1];
    const padPerCol = new Array(numCols).fill(0);
    if (shape.length > 1) {
        // complex64 values are formatted as [re, im] tuples.
        const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
        const numRows = n / numCols;
        for (let row = 0; row < numRows; row++) {
            const offset = row * numCols;
            for (let col = 0; col < numCols; col++) {
                const width = valToString(valuesOrTuples[offset + col], 0, dtype).length;
                if (width > padPerCol[col]) {
                    padPerCol[col] = width;
                }
            }
        }
    }
    return padPerCol;
}
/**
 * Formats a single value for display: complex tuples as "re + imj", strings
 * quoted, bools as true/false, numbers rounded to FORMAT_NUM_SIG_DIGITS.
 * The result is right-padded to `pad` characters.
 */
function valToString(val, pad, dtype) {
    let valStr;
    if (Array.isArray(val)) {
        // Complex value: [real, imaginary].
        const re = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS));
        const im = parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS));
        valStr = `${re} + ${im}j`;
    }
    else if (isString(val)) {
        valStr = `'${val}'`;
    }
    else if (dtype === 'bool') {
        valStr = boolNumToString(val);
    }
    else {
        valStr = `${parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS))}`;
    }
    return rightPad(valStr, pad);
}
/** Renders a numeric bool (0/1) as 'false'/'true'; any non-zero is 'true'. */
function boolNumToString(v) {
    return v !== 0 ? 'true' : 'false';
}
/**
 * Recursively renders a (sub-)tensor's values as an array of text lines,
 * eliding the middle of any dimension longer than FORMAT_LIMIT_NUM_VALS.
 *
 * @param vals Flat values for this sub-tensor (complex64 interleaves re/im).
 * @param shape Shape of this sub-tensor.
 * @param dtype Data type; controls per-value formatting.
 * @param strides Strides matching `shape`.
 * @param padPerCol Per-column widths from computeMaxSizePerColumn.
 * @param isLast Whether this is the last sibling at its nesting level
 *     (controls trailing separators).
 */
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
    // complex64 stores [re, im] pairs, so each logical element spans 2 slots.
    const storagePerElement = dtype === 'complex64' ? 2 : 1;
    const size = shape[0];
    const rank = shape.length;
    if (rank === 0) {
        // Scalar: a single formatted value, no brackets.
        if (dtype === 'complex64') {
            const complexTuple = createComplexTuples(vals);
            return [valToString(complexTuple[0], 0, dtype)];
        }
        if (dtype === 'bool') {
            return [boolNumToString(vals[0])];
        }
        return [vals[0].toString()];
    }
    if (rank === 1) {
        if (size > FORMAT_LIMIT_NUM_VALS) {
            // Long vector: show only the first and last few values with '...'.
            const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
            let firstVals = Array.from(vals.slice(0, firstValsSize));
            let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
            if (dtype === 'complex64') {
                firstVals = createComplexTuples(firstVals);
                lastVals = createComplexTuples(lastVals);
            }
            return [
                '[' +
                    firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                        .join(', ') +
                    ', ..., ' +
                    lastVals
                        .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
                        .join(', ') +
                    ']'
            ];
        }
        const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
            Array.from(vals);
        return [
            '[' +
                displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                    .join(', ') +
                ']'
        ];
    }
    // The array is rank 2 or more.
    const subshape = shape.slice(1);
    const substrides = strides.slice(1);
    // Number of flat slots per slice along the first axis.
    const stride = strides[0] * storagePerElement;
    const lines = [];
    if (size > FORMAT_LIMIT_NUM_VALS) {
        // Render only the first/last few slices with '...' in between.
        for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
        }
        lines.push('...');
        for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
        }
    }
    else {
        for (let i = 0; i < size; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
        }
    }
    // Wrap child lines in brackets; rank-2 rows get a ',' suffix per line.
    const sep = rank === 2 ? ',' : '';
    lines[0] = '[' + lines[0] + sep;
    for (let i = 1; i < lines.length - 1; i++) {
        lines[i] = ' ' + lines[i] + sep;
    }
    // Higher-rank blocks are separated by one extra blank line per dimension.
    let newLineSep = ',\n';
    for (let i = 2; i < rank; i++) {
        newLineSep += '\n';
    }
    lines[lines.length - 1] =
        ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
    return lines;
}
/**
 * Groups an interleaved complex64 value array [re0, im0, re1, im1, ...] into
 * [ [re0, im0], [re1, im1], ... ] tuples.
 */
function createComplexTuples(vals) {
    const complexTuples = [];
    let i = 0;
    while (i < vals.length) {
        complexTuples.push([vals[i], vals[i + 1]]);
        i += 2;
    }
    return complexTuples;
}
/** | |
* @license | |
* Copyright 2017 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * A mutable object, similar to `tf.Tensor`, that allows users to set values
 * at locations before converting to an immutable `tf.Tensor`.
 *
 * See `tf.buffer` for creating a tensor buffer.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class TensorBuffer {
    /**
     * @param shape Buffer shape; defensively copied.
     * @param dtype Data type. 'complex64' is rejected (see below).
     * @param values Optional backing storage; its length must equal the size
     *     implied by `shape`. When omitted, a fresh array of the right dtype
     *     is allocated.
     */
    constructor(shape, dtype, values) {
        this.dtype = dtype;
        this.shape = shape.slice();
        this.size = sizeFromShape(shape);
        if (values != null) {
            const n = values.length;
            assert(n === this.size, () => `Length of values '${n}' does not match the size ` +
                `inferred by the shape '${this.size}'.`);
        }
        if (dtype === 'complex64') {
            throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
                `a TensorBuffer for the real and imaginary parts separately and ` +
                `call tf.complex(real, imag).`);
        }
        this.values = values || getArrayFromDType(dtype, this.size);
        this.strides = computeStrides(shape);
    }
    /**
     * Sets a value in the buffer at a given location.
     *
     * @param value The value to set.
     * @param locs The location indices.
     */
    /** @doc {heading: 'Tensors', subheading: 'Creation'} */
    set(value, ...locs) {
        if (locs.length === 0) {
            locs = [0];
        }
        // NOTE(review): for a rank-0 buffer this assert compares 1 !== 0 and
        // throws, so scalar writes via set() appear unsupported — confirm
        // whether that is intended.
        assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
            `match the rank (${this.rank})`);
        const index = this.locToIndex(locs);
        this.values[index] = value;
    }
    /**
     * Returns the value in the buffer at the provided location.
     *
     * @param locs The location indices.
     */
    /** @doc {heading: 'Tensors', subheading: 'Creation'} */
    get(...locs) {
        if (locs.length === 0) {
            locs = [0];
        }
        let i = 0;
        for (const loc of locs) {
            if (loc < 0 || loc >= this.shape[i]) {
                const msg = `Requested out of range element at ${locs}. ` +
                    ` Buffer shape=${this.shape}`;
                throw new Error(msg);
            }
            i++;
        }
        // Delegate to locToIndex rather than duplicating the stride arithmetic
        // inline; the previous inline copy was identical to locToIndex's
        // rank>1 path and agreed with it for ranks 0 and 1 as well.
        return this.values[this.locToIndex(locs)];
    }
    /** Flattens a multidimensional location into an offset into `values`. */
    locToIndex(locs) {
        if (this.rank === 0) {
            return 0;
        }
        else if (this.rank === 1) {
            return locs[0];
        }
        let index = locs[locs.length - 1];
        for (let i = 0; i < locs.length - 1; ++i) {
            index += this.strides[i] * locs[i];
        }
        return index;
    }
    /** Expands a flat offset into a multidimensional location. */
    indexToLoc(index) {
        if (this.rank === 0) {
            return [];
        }
        else if (this.rank === 1) {
            return [index];
        }
        const locs = new Array(this.shape.length);
        for (let i = 0; i < locs.length - 1; ++i) {
            locs[i] = Math.floor(index / this.strides[i]);
            index -= locs[i] * this.strides[i];
        }
        locs[locs.length - 1] = index;
        return locs;
    }
    get rank() {
        return this.shape.length;
    }
    /**
     * Creates an immutable `tf.Tensor` object from the buffer.
     */
    /** @doc {heading: 'Tensors', subheading: 'Creation'} */
    toTensor() {
        return trackerFn().makeTensor(this.values, this.shape, this.dtype);
    }
}
// For tracking tensor creation and disposal. Stored as a factory and invoked
// lazily (trackerFn()) wherever a tracker is needed, e.g. Tensor.dispose and
// TensorBuffer.toTensor.
let trackerFn = null;
// Used by chaining methods to call into ops.
let opHandler = null;
/**
 * An external consumer can register itself as the tensor tracker. This way
 * the Tensor class can notify the tracker for every tensor created and
 * disposed.
 */
function setTensorTracker(fn) {
    trackerFn = fn;
}
/**
 * An external consumer can register itself as the op handler. This way the
 * Tensor class can have chaining methods that call into ops via the op
 * handler.
 */
function setOpHandler(handler) {
    opHandler = handler;
}
/**
 * A `tf.Tensor` object represents an immutable, multidimensional array of
 * numbers that has a shape and a data type.
 *
 * See `tf.tensor` for details on how to create a `tf.Tensor`.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class Tensor {
    /**
     * @param shape Tensor shape; defensively copied.
     * @param dtype Data type; defaults to 'float32' when falsy.
     * @param dataId Identifies the backing data (may be shared with other
     *     tensors, e.g. after Variable.assign).
     * @param id Unique id of this Tensor object itself.
     */
    constructor(shape, dtype, dataId, id) {
        /** Whether this tensor has been globally kept. */
        this.kept = false;
        this.isDisposedInternal = false;
        this.shape = shape.slice();
        this.dtype = dtype || 'float32';
        this.size = sizeFromShape(shape);
        this.strides = computeStrides(shape);
        this.dataId = dataId;
        this.id = id;
        // Ranks 0-4 get their own string; everything else is 'higher'.
        this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
    }
    get rank() {
        return this.shape.length;
    }
    /**
     * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    async buffer() {
        const vals = await this.data();
        return opHandler.buffer(this.shape, this.dtype, vals);
    }
    /** Returns a `tf.TensorBuffer` that holds the underlying data. */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    bufferSync() {
        return opHandler.buffer(this.shape, this.dtype, this.dataSync());
    }
    /**
     * Returns the tensor data as a nested array. The transfer of data is done
     * asynchronously.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    async array() {
        const vals = await this.data();
        return toNestedArray(this.shape, vals);
    }
    /**
     * Returns the tensor data as a nested array. The transfer of data is done
     * synchronously.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    arraySync() {
        return toNestedArray(this.shape, this.dataSync());
    }
    /**
     * Asynchronously downloads the values from the `tf.Tensor`. Returns a
     * promise of `TypedArray` that resolves when the computation has finished.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    async data() {
        this.throwIfDisposed();
        const data = trackerFn().read(this.dataId);
        if (this.dtype === 'string') {
            // String tensors are stored as raw bytes; decode each element.
            const bytes = await data;
            try {
                return bytes.map(b => decodeString(b));
            }
            catch (_a) {
                throw new Error('Failed to decode the string bytes into utf-8. ' +
                    'To get the original bytes, call tensor.bytes().');
            }
        }
        return data;
    }
    /**
     * Synchronously downloads the values from the `tf.Tensor`. This blocks the
     * UI thread until the values are ready, which can cause performance issues.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    dataSync() {
        this.throwIfDisposed();
        const data = trackerFn().readSync(this.dataId);
        if (this.dtype === 'string') {
            try {
                return data.map(b => decodeString(b));
            }
            catch (_a) {
                throw new Error('Failed to decode the string bytes into utf-8. ' +
                    'To get the original bytes, call tensor.bytes().');
            }
        }
        return data;
    }
    /** Returns the underlying bytes of the tensor's data. */
    async bytes() {
        this.throwIfDisposed();
        const data = await trackerFn().read(this.dataId);
        if (this.dtype === 'string') {
            // Already an array of byte arrays; return as-is.
            return data;
        }
        else {
            // View the typed array's buffer as raw bytes.
            return new Uint8Array(data.buffer);
        }
    }
    /**
     * Disposes `tf.Tensor` from memory. Safe to call more than once; repeat
     * calls are no-ops.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    dispose() {
        if (this.isDisposed) {
            return;
        }
        trackerFn().disposeTensor(this);
        this.isDisposedInternal = true;
    }
    get isDisposed() {
        return this.isDisposedInternal;
    }
    throwIfDisposed() {
        if (this.isDisposed) {
            throw new Error(`Tensor is disposed.`);
        }
    }
    /**
     * Prints the `tf.Tensor`. See `tf.print` for details.
     *
     * @param verbose Whether to print verbose information about the tensor,
     * including dtype and size.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    print(verbose = false) {
        return opHandler.print(this, verbose);
    }
    /** Returns a copy of the tensor. See `tf.clone` for details. */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    clone() {
        this.throwIfDisposed();
        return opHandler.clone(this);
    }
    /**
     * Returns a human-readable description of the tensor. Useful for logging.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    toString(verbose = false) {
        const vals = this.dataSync();
        return tensorToString(vals, this.shape, this.dtype, verbose);
    }
    // Casts the tensor to a new dtype via the registered op handler.
    cast(dtype) {
        this.throwIfDisposed();
        return opHandler.cast(this, dtype);
    }
    // Converts this tensor to a Variable; `name` and `dtype` are optional.
    variable(trainable = true, name, dtype) {
        this.throwIfDisposed();
        return trackerFn().makeVariable(this, trainable, name, dtype);
    }
}
// `instanceof Tensor` is duck-typed: any non-null object carrying dataId,
// shape and dtype passes, regardless of its constructor. NOTE(review):
// presumably so tensors created by another copy of the library still pass
// instanceof checks — confirm.
Object.defineProperty(Tensor, Symbol.hasInstance, {
    value: (instance) => {
        return !!instance && instance.dataId != null && instance.shape != null &&
            instance.dtype != null;
    }
});
/**
 * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
class Variable extends Tensor {
    constructor(initialValue, trainable, name, tensorId) {
        // Shares the initial value's backing data (same dataId).
        super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
        this.trainable = trainable;
        this.name = name;
    }
    /**
     * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
     * the same shape and dtype as the old `tf.Tensor`.
     *
     * @param newValue New tensor to be assigned to this variable.
     */
    /** @doc {heading: 'Tensors', subheading: 'Classes'} */
    assign(newValue) {
        if (newValue.dtype !== this.dtype) {
            throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
                `previous value (${this.dtype}) must match`);
        }
        if (!arraysEqual(newValue.shape, this.shape)) {
            throw new Error(`shape of the new value (${newValue.shape}) and ` +
                `previous value (${this.shape}) must match`);
        }
        // Release the old data first, then point at the new tensor's data and
        // incRef it so the data survives `newValue` being disposed.
        trackerFn().disposeTensor(this);
        this.dataId = newValue.dataId;
        trackerFn().incRef(this, null /* backend */);
    }
    dispose() {
        trackerFn().disposeVariable(this);
        this.isDisposedInternal = true;
    }
}
// `instanceof Variable` is likewise structural: anything that passes the
// Tensor check and has a callable `assign` method qualifies.
Object.defineProperty(Variable, Symbol.hasInstance, {
    value: (instance) => {
        return instance instanceof Tensor && instance.assign != null &&
            instance.assign instanceof Function;
    }
});
/** | |
* @license | |
* Copyright 2017 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
// String "enum" of tensor ranks R0..R6 (TypeScript-enum compilation pattern;
// each key maps to its own name).
var Rank;
(function (Rank) {
    Rank["R0"] = "R0";
    Rank["R1"] = "R1";
    Rank["R2"] = "R2";
    Rank["R3"] = "R3";
    Rank["R4"] = "R4";
    Rank["R5"] = "R5";
    Rank["R6"] = "R6";
})(Rank || (Rank = {}));
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs. Each map below answers: given one operand of the map's dtype, what
// is the result dtype for the other operand's dtype?
// With int32: bool promotes to int32; float32/complex64 win over int32.
var UpcastInt32AndMap;
(function (UpcastInt32AndMap) {
    UpcastInt32AndMap["float32"] = "float32";
    UpcastInt32AndMap["int32"] = "int32";
    UpcastInt32AndMap["bool"] = "int32";
    UpcastInt32AndMap["complex64"] = "complex64";
})(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
// With bool: the other dtype always wins (bool+bool stays bool).
var UpcastBoolAndMap;
(function (UpcastBoolAndMap) {
    UpcastBoolAndMap["float32"] = "float32";
    UpcastBoolAndMap["int32"] = "int32";
    UpcastBoolAndMap["bool"] = "bool";
    UpcastBoolAndMap["complex64"] = "complex64";
})(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
// With float32: everything except complex64 becomes float32.
var UpcastFloat32AndMap;
(function (UpcastFloat32AndMap) {
    UpcastFloat32AndMap["float32"] = "float32";
    UpcastFloat32AndMap["int32"] = "float32";
    UpcastFloat32AndMap["bool"] = "float32";
    UpcastFloat32AndMap["complex64"] = "complex64";
})(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
// With complex64: the result is always complex64.
var UpcastComplex64AndMap;
(function (UpcastComplex64AndMap) {
    UpcastComplex64AndMap["float32"] = "complex64";
    UpcastComplex64AndMap["int32"] = "complex64";
    UpcastComplex64AndMap["bool"] = "complex64";
    UpcastComplex64AndMap["complex64"] = "complex64";
})(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
// Two-level lookup: upcastTypeMap[typeA][typeB] -> result dtype.
const upcastTypeMap = {
    'float32': UpcastFloat32AndMap,
    'int32': UpcastInt32AndMap,
    'bool': UpcastBoolAndMap,
    'complex64': UpcastComplex64AndMap
};
/**
 * Returns the common dtype two operands should be upcast to.
 * 'string' only combines with 'string'; mixing 'string' with any numeric
 * dtype throws.
 */
function upcastType(typeA, typeB) {
    if (typeA === 'string' && typeB === 'string') {
        return 'string';
    }
    if (typeA === 'string' || typeB === 'string') {
        throw new Error(`Can not upcast ${typeA} with ${typeB}`);
    }
    return upcastTypeMap[typeA][typeB];
}
/** Returns the output type after summation (bool sums become int32). */
function sumOutType(type) {
    return upcastType(type, 'int32');
}
/** | |
* @license | |
* Copyright 2018 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * Casts `a` and `b` to their common upcast dtype when their dtypes differ;
 * returns the pair unchanged otherwise.
 */
function makeTypesMatch(a, b) {
    if (a.dtype !== b.dtype) {
        const commonType = upcastType(a.dtype, b.dtype);
        return [a.cast(commonType), b.cast(commonType)];
    }
    return [a, b];
}
/**
 * Asserts that the two tensors share the same dtype.
 *
 * @param a first tensor.
 * @param b second tensor.
 */
function assertTypesMatch(a, b) {
    assert(a.dtype === b.dtype, () => `The dtypes of the first(${a.dtype}) and` +
        ` second(${b.dtype}) input must match`);
}
/** Returns true if `tensorList` contains a tensor with the same id as `tensor`. */
function isTensorInList(tensor, tensorList) {
    for (const candidate of tensorList) {
        if (candidate.id === tensor.id) {
            return true;
        }
    }
    return false;
}
/**
 * Extracts any `Tensor`s found within the provided object.
 *
 * @param result an object that may be a `Tensor` or may directly contain
 *     `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
 *     is safe to pass any object here, except that `Promise`s are not
 *     supported.
 * @returns An array of `Tensors` found within the passed object. If the
 *     argument is simply a `Tensor`, a list containing that `Tensor` is
 *     returned. If the object is not a `Tensor` or does not contain
 *     `Tensors`, an empty list is returned.
 */
function getTensorsInContainer(result) {
    const found = [];
    walkTensorContainer(result, found, new Set());
    return found;
}
/**
 * Recursively collects every `Tensor` reachable from `container` into `list`,
 * using `seen` to avoid revisiting values (handles cycles and duplicates).
 */
function walkTensorContainer(container, list, seen) {
    if (container == null) {
        return;
    }
    if (container instanceof Tensor) {
        list.push(container);
        return;
    }
    if (!isIterable(container)) {
        return;
    }
    // for..in enumerates array indices as well as object keys, so a single
    // loop handles both arrays and plain objects.
    for (const key in container) {
        const value = container[key];
        if (seen.has(value)) {
            continue;
        }
        seen.add(value);
        walkTensorContainer(value, list, seen);
    }
}
// tslint:disable-next-line:no-any
function isIterable(obj) {
    // Note: typeof null === 'object', so null is (deliberately) considered
    // iterable here; callers null-check before recursing.
    if (Array.isArray(obj)) {
        return true;
    }
    return typeof obj === 'object';
}
// Grouped namespace object for the tensor utility helpers defined above.
var tensor_util = {
    __proto__: null,
    makeTypesMatch,
    assertTypesMatch,
    isTensorInList,
    getTensorsInContainer
};
/** | |
* @license | |
* Copyright 2018 Google LLC. All Rights Reserved. | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* ============================================================================= | |
*/ | |
/**
 * Mutable bookkeeping owned by the Engine: registered variables, tensor and
 * byte counters, tape/scope tracking state, and the active profile.
 */
class EngineState {
    constructor() {
        // Public since optimizers will use it.
        this.registeredVariables = {};
        this.nextTapeNodeId = 0;
        this.numBytes = 0;
        this.numTensors = 0;
        this.numStringTensors = 0;
        this.numDataBuffers = 0;
        // Depth of nested tf.grad() calls: 1 for first-order gradients, 2 for
        // second-order, etc. Used to decide when the tape can be removed after
        // a backprop pass.
        this.gradientDepth = 0;
        // Depth of nested kernel calls; the tape is switched off while this
        // is greater than 0.
        this.kernelDepth = 0;
        this.scopeStack = [];
        // Stack of per-kernel data-move counters; a stack because kernels can
        // invoke other kernels recursively.
        this.numDataMovesStack = [];
        this.nextScopeId = 0;
        this.tensorInfo = new WeakMap();
        this.profiling = false;
        this.activeProfile = {
            newBytes: 0,
            newTensors: 0,
            peakBytes: 0,
            kernels: [],
            result: null
        };
    }
    /** Disposes every registered variable. */
    dispose() {
        for (const variable of Object.values(this.registeredVariables)) {
            variable.dispose();
        }
    }
}
class Engine { | |
    constructor(ENV) {
        this.ENV = ENV;
        // Map of backend name -> live backend instance.
        this.registry = {};
        // Map of backend name -> {factory, priority} for lazy initialization.
        this.registryFactory = {};
        // Monotonic id used to invalidate outdated async backend-init promises.
        this.pendingBackendInitId = 0;
        this.state = new EngineState();
    }
    /**
     * Resolves when a backend is ready for use. Waits on an in-flight async
     * initialization if there is one; otherwise tries registered backends in
     * priority order until one initializes successfully.
     */
    async ready() {
        if (this.pendingBackendInit != null) {
            // An async init is already in flight; wait for it and ignore its
            // boolean result.
            return this.pendingBackendInit.then(() => { });
        }
        if (this.backendInstance != null) {
            return;
        }
        const sortedBackends = this.getSortedBackends();
        for (let i = 0; i < sortedBackends.length; i++) {
            const backendName = sortedBackends[i];
            // `success` is either a boolean (sync init) or a promise of one
            // (async init); `await` handles both uniformly.
            const success = await this.initializeBackend(backendName).success;
            if (success) {
                await this.setBackend(backendName);
                return;
            }
        }
        throw new Error(`Could not initialize any backends, all backend initializations ` +
            `failed.`);
    }
    /**
     * The active backend instance. Throws while an async backend init is
     * pending; lazily initializes the highest-priority synchronous backend
     * when none is set yet.
     */
    get backend() {
        if (this.pendingBackendInit != null) {
            throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
                `sure to await tf.ready() or await tf.setBackend() before calling ` +
                `other methods`);
        }
        if (this.backendInstance == null) {
            const { name, asyncInit } = this.initializeBackendsAndReturnBest();
            if (asyncInit) {
                // The best backend requires async init; callers must use
                // tf.ready()/tf.setBackend() instead of this getter.
                throw new Error(`The highest priority backend '${name}' has not yet been ` +
                    `initialized. Make sure to await tf.ready() or ` +
                    `await tf.setBackend() before calling other methods`);
            }
            // setBackend is async, but the backend was just initialized
            // synchronously (asyncInit is false), so the instance is assigned
            // before any await point and the promise can be safely ignored.
            this.setBackend(name);
        }
        return this.backendInstance;
    }
backendNames() { | |
return Object.keys(this.registryFactory); | |
} | |
findBackend(backendName) { | |
if (!(backendName in this.registry)) { | |
// If the backend hasn't been initialized but we have a registry entry for | |
// it, initialize it and return it. | |
if (backendName in this.registryFactory) { | |
const { asyncInit } = this.initializeBackend(backendName); | |
if (asyncInit) { | |
// Backend is not ready yet. | |
return null; | |
} | |
} | |
else { | |
return null; | |
} | |
} | |
return this.registry[backendName]; | |
} | |
findBackendFactory(backendName) { | |
if (!(backendName in this.registryFactory)) { | |
return null; | |
} | |
return this.registryFactory[backendName].factory; | |
} | |
registerBackend(backendName, factory, priority = 1) { | |
if (backendName in this.registryFactory) { | |
console.warn(`${backendName} backend was already registered. ` + | |
`Reusing existing backend factory.`); | |
return false; | |
} | |
this.registryFactory[backendName] = { factory, priority }; | |
return true; | |
} | |
    /**
     * Makes `backendName` the active backend, initializing it first when it
     * has not been instantiated yet. Resolves to true on success, false when
     * initialization fails.
     */
    async setBackend(backendName) {
        if (this.registryFactory[backendName] == null) {
            throw new Error(`Backend name '${backendName}' not found in registry`);
        }
        this.backendName = backendName;
        if (this.registry[backendName] == null) {
            // Registered but not yet instantiated; initialize now (possibly
            // asynchronously).
            this.backendInstance = null;
            const { success, asyncInit } = this.initializeBackend(backendName);
            const result = asyncInit ? await success : success;
            if (!result) {
                return false;
            }
        }
        this.backendInstance = this.registry[backendName];
        this.setupRegisteredKernels();
        // Reset the profiler.
        this.profiler = new Profiler(this.backendInstance);
        return true;
    }
setupRegisteredKernels() { | |
const kernels = getKernelsForBackend(this.backendName); | |
kernels.forEach(kernel => { | |
if (kernel.setupFunc != null) { | |
kernel.setupFunc(this.backendInstance); | |
} | |
}); | |
} | |
disposeRegisteredKernels(backendName) { | |
const kernels = getKernelsForBackend(backendName); | |
kernels.forEach(kernel => { | |
if (kernel.disposeFunc != null) { | |
kernel.disposeFunc(this.registry[backendName]); | |
} | |
}); | |
} | |
    /**
     * Initializes a backend by looking up the backend name in the factory
     * registry and calling the factory method. Returns `{success, asyncInit}`
     * where `success` is a boolean (or, when `asyncInit` is true, a promise
     * of one) indicating whether the initialization succeeded. Throws an
     * error if there is no backend in the factory registry.
     */
    initializeBackend(backendName) {
        const registryFactoryEntry = this.registryFactory[backendName];
        if (registryFactoryEntry == null) {
            throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
        }
        try {
            const backend = registryFactoryEntry.factory();
            // Test if the factory returns a promise (Promise.resolve returns
            // its argument unchanged when it is already a promise).
            if (Promise.resolve(backend) === backend) {
                // Bump the id so previously pending inits become stale.
                const promiseId = ++this.pendingBackendInitId;
                const success = backend
                    .then(backendInstance => {
                    // Outdated promise. Another backend was set in the meantime.
                    if (promiseId < this.pendingBackendInitId) {
                        return false;
                    }
                    this.registry[backendName] = backendInstance;
                    this.pendingBackendInit = null;
                    return true;
                })
                    .catch(err => {
                    // Outdated promise. Another backend was set in the meantime.
                    if (promiseId < this.pendingBackendInitId) {
                        return false;
                    }
                    this.pendingBackendInit = null;
                    console.warn(`Initialization of backend ${backendName} failed`);
                    console.warn(err.stack || err.message);
                    return false;
                });
                this.pendingBackendInit = success;
                return { success, asyncInit: true };
            }
            else {
                this.registry[backendName] = backend;
                return { success: true, asyncInit: false };
            }
        }
        catch (err) {
            // Synchronous factory failure: warn and report failure.
            console.warn(`Initialization of backend ${backendName} failed`);
            console.warn(err.stack || err.message);
            return { success: false, asyncInit: false };
        }
    }
    /**
     * Removes a backend and its factory from the registries, disposing the
     * live instance if one exists. If the removed backend is the active one,
     * the active backend is unset.
     */
    removeBackend(backendName) {
        if (!(backendName in this.registryFactory)) {
            throw new Error(`${backendName} backend not found in registry`);
        }
        if (this.backendName === backendName && this.pendingBackendInit != null) {
            // There is a pending promise of the backend we want to remove. Make it
            // obsolete.
            this.pendingBackendInitId++;
        }
        if (backendName in this.registry) {
            // Dispose kernels first, then the backend instance itself.
            this.disposeRegisteredKernels(backendName);
            this.registry[backendName].dispose();
            delete this.registry[backendName];
        }
        delete this.registryFactory[backendName];
        // Unset the backend if it is active.
        if (this.backendName === backendName) {
            this.pendingBackendInit = null;
            this.backendName = null;
            this.backendInstance = null;
        }
    }
getSortedBackends() { | |
if (Object.keys(this.registryFactory).length === 0) { | |
throw new Error('No backend found in registry.'); | |
} | |
return Object.keys(this.registryFactory).sort((a, b) => { | |
// Highest priority comes first. | |
return this.registryFactory[b].priority - | |
this.registryFactory[a].priority; | |
}); | |
} | |
initializeBackendsAndReturnBest() { | |
const sortedBackends = this.getSortedBackends(); | |
for (let i = 0; i < sortedBackends.length; i++) { | |
const backendName = sortedBackends[i]; | |
const { success, asyncInit } = this.initializeBackend(backendName); | |
if (asyncInit || success) { | |
return { name: backendName, asyncInit }; | |
} | |
} | |
throw new Error(`Could not initialize any backends, all backend initializations ` + | |
`failed.`); | |
} | |
    /**
     * Moves the data behind `dataId` from its current backend to `backend`:
     * reads the values synchronously, disposes the data on the source
     * backend, then writes it to the destination.
     */
    moveData(backend, dataId) {
        const info = this.state.tensorInfo.get(dataId);
        const srcBackend = info.backend;
        const values = this.readSync(dataId);
        // Delete the tensor from the old backend and move it to the new
        // backend.
        srcBackend.disposeData(dataId);
        info.backend = backend;
        backend.move(dataId, values, info.shape, info.dtype);
        if (this.shouldCheckForMemLeaks()) {
            // Track the number of moves during a kernel execution to correctly
            // detect memory leaks.
            this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
        }
    }
    /**
     * Executes `fn` inside a fresh scope (startScope/endScope), so tensors
     * created within it are tracked by the scope machinery; the returned
     * result is handed to endScope so it can be preserved. Accepts either
     * `tidy(fn)` or `tidy(name, fn)`.
     */
    tidy(nameOrFn, fn) {
        let name = null;
        if (fn == null) {
            // Called with only 1 argument.
            if (typeof nameOrFn !== 'function') {
                throw new Error('Please provide a function to tidy()');
            }
            fn = nameOrFn;
        }
        else {
            // Called with 2 arguments.
            if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
                throw new Error('When calling with two arguments, the first argument ' +
                    'to tidy() must be a string');
            }
            if (typeof fn !== 'function') {
                throw new Error('When calling with two arguments, the 2nd argument ' +
                    'to tidy() must be a function');
            }
            name = nameOrFn;
            // TODO(nsthorat,smilkov): Do operation logging and performance
            // profiling.
        }
        let result;
        // `result` is captured so endScope can see what `fn` returned.
        return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
            result = fn();
            if (result instanceof Promise) {
                // Async work would outlive the scope; warn loudly.
                console.error('Cannot return a Promise inside of tidy.');
            }
            return result;
        });
    }
scopedRun(start, end, f) { | |
start(); | |
try { | |
const res = f(); | |
end(); | |
return res; | |
} | |
catch (ex) { | |
end(); | |
throw ex; | |
} | |
} | |
    // Returns the next unique tensor id from the Engine-level static counter.
    nextTensorId() {
        return Engine.nextTensorId++;
    }
    // Returns the next unique variable id from the Engine-level static counter.
    nextVariableId() {
        return Engine.nextVariableId++;
    }
    /**
     * This method is called instead of the public-facing tensor.clone() when
     * saving a tensor for backwards pass. It makes sure to add the clone
     * operation to the tape regardless of being called inside a kernel
     * execution.
     *
     * This method will go away once all kernels are modularized since we won't
     * need to turn off the tape inside runKernel().
     */
    clone(x) {
        // Wrap the same data id in a new tensor (no data copy here).
        const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
        const inputs = { x };
        // The gradient of clone is the identity (cast to float).
        const grad = (dy) => ({ x: () => dy.toFloat() });
        const saved = [];
        this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
        return y;
    }
    /**
     * Execute a kernel with the given name and return the output tensor.
     *
     * @param kernelName The name of the kernel to execute.
     * @param inputs A map of input names to tensors.
     * @param attrs A map of attribute names to their values. An attribute is a
     *     primitive (non-tensor) input to the kernel.
     * @param inputsToSave A list of tensors, inputs to save for the backprop
     *     computation.
     * @param outputsToSave A list of booleans, specifying which output to save
     *     for the backprop computation. These are booleans since the output
     *     tensors are not visible to the user.
     */
    runKernel(kernelName, inputs, attrs, inputsToSave, outputsToSave) {
        // No inline forward/backward functions: the kernel (and its gradient)
        // are looked up by name from the registries inside runKernelFunc.
        const forwardFunc = null;
        const backwardsFunc = null;
        // Call runKernelFunc as a stop-gap until we modularize all kernels.
        // Once we modularize all kernels, we will remove the existing
        // `runKernelFunc`.
        return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
    }
    // Memory-leak checks are only enabled when the IS_TEST flag is set.
    shouldCheckForMemLeaks() {
        return this.ENV.getBool('IS_TEST');
    }
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) { | |
const numDataIdsAfter = this.backend.numDataIds(); | |
// Count the number of data ids associated with the result of the kernel. | |
let numOutputDataIds = 0; | |
outInfos.forEach(info => { | |
// Complex numbers allocate 3 data ids, one for 'real', one for | |
// 'imaginary', and one for the container that holds the former two. | |
numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); | |
}); | |
// Account for the number of moves during kernel execution. A "data move" | |
// can happen in the middle of a kernel execution, placing a new (key,value) | |
// pair in the data storage. Since data moves have net zero effect (we | |
// always remove the data from the old backend), we have to cancel them out | |
// when detecting memory leaks. | |
const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; | |
const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; | |
if (dataIdsLeaked > 0) { | |
throw new Error(`Backend '${this.backendName}' has an internal memory leak ` + | |
`(${dataIdsLeaked} data ids) after running '${kernelName}'`); | |
} | |
} | |
    /**
     * @deprecated Use `runKernel` for newly added kernels. Keep using this method
     * only for kernels that are not yet fully modularized.
     */
    runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
        let outputs;
        let saved = [];
        const isTapeOn = this.isTapeOn();
        if (kernelName == null) {
            // Fall back to the active scope's name for tape/profile entries.
            kernelName =
                this.state.activeScope != null ? this.state.activeScope.name : '';
        }
        // Snapshots used to report per-kernel deltas when profiling.
        const startingBytecount = this.state.numBytes;
        const startingNumTensors = this.state.numTensors;
        if (this.shouldCheckForMemLeaks()) {
            this.state.numDataMovesStack.push(0);
        }
        let kernelFunc;
        const kernel = getKernel(kernelName, this.backendName);
        let out;
        if (kernel != null) {
            // Modular path: a kernel is registered for this backend by name.
            kernelFunc = () => {
                const numDataIdsBefore = this.backend.numDataIds();
                out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
                const outInfos = Array.isArray(out) ? out : [out];
                if (this.shouldCheckForMemLeaks()) {
                    this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
                }
                // Wrap the raw output infos into tensors (no data copies).
                const outTensors = outInfos.map(({ dataId, shape, dtype }) => this.makeTensorFromDataId(dataId, shape, dtype));
                // Save the inputs and outputs.
                // Do not save unless we are recording to the tape. Otherwise it would
                // cause a mem leak since we would never run backprop, which disposes
                // the kept tensors.
                if (isTapeOn) {
                    let tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
                    if (tensorsToSave == null) {
                        // Fallback for ops that call runKernelFunc and pass in
                        // inputsToSave and outputsToSave. Currently this is the set of ops
                        // with kernel support in the WASM backend. Once those ops and
                        // respective gradients are modularised we can remove this path.
                        if (outputsToSave == null) {
                            outputsToSave = [];
                        }
                        const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
                        tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
                    }
                    saved = this.saveTensorsForBackwardMode(tensorsToSave);
                }
                return outTensors;
            };
        }
        else {
            // Legacy path: no registered kernel; run the caller's forwardFunc.
            const saveFunc = (tensors) => {
                // Do not save unless we are recording to the tape. Otherwise it would
                // cause a mem leak since we would never run backprop, which disposes
                // the kept tensors.
                if (!isTapeOn) {
                    return;
                }
                saved = tensors.map(tensor => this.keep(this.clone(tensor)));
            };
            kernelFunc = () => {
                const numDataIdsBefore = this.backend.numDataIds();
                out = this.tidy(() => forwardFunc(this.backend, saveFunc));
                const outs = (Array.isArray(out) ? out : [out]);
                if (this.shouldCheckForMemLeaks()) {
                    this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
                }
                return outs;
            };
        }
        // Stop recording to a tape when running a kernel.
        this.scopedRun(() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
            if (!this.ENV.getBool('DEBUG')) {
                outputs = kernelFunc();
            }
            else {
                // DEBUG mode wraps the kernel call in the profiler.
                outputs = this.profiler.profileKernel(kernelName, inputs, () => kernelFunc());
            }
        });
        if (isTapeOn) {
            this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
        }
        if (this.state.profiling) {
            this.state.activeProfile.kernels.push({
                name: kernelName,
                bytesAdded: this.state.numBytes - startingBytecount,
                totalBytesSnapshot: this.state.numBytes,
                tensorsAdded: this.state.numTensors - startingNumTensors,
                totalTensorsSnapshot: this.state.numTensors,
                inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
                outputShapes: outputs.map(item => item.shape)
            });
        }
        // Unwrap a single output when the forward result was not an array.
        return (Array.isArray(out) ? outputs : outputs[0]);
    }
/** | |
* Saves tensors used in forward mode for use in backward mode. | |
* | |
* @param tensors the list of tensors to save. | |
*/ | |
saveTensorsForBackwardMode(tensors) { | |
const saved = tensors.map(tensor => this.keep(this.clone(tensor))); | |
return saved; | |
} | |
    /**
     * Returns a list of tensors to save for a given gradient calculation.
     *
     * Returns null if there is no registered gradient for this kernel in the
     * gradient registry.
     *
     * @param kernelName name of kernel to look up gradient for.
     * @param inputs a map of input tensors.
     * @param outputs an array of output tensors from forward mode of kernel.
     */
    getTensorsForGradient(kernelName, inputs, outputs) {
        const gradConfig = getGradient(kernelName);
        if (gradConfig != null) {
            const inputsToSave = gradConfig.inputsToSave || [];
            const outputsToSave = gradConfig.outputsToSave || [];
            // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
            // specified in inputsToSave will be saved.
            let inputTensorsToSave;
            if (gradConfig.saveAllInputs) {
                // NOTE(review): the assertion expects an array, yet the line
                // below iterates Object.keys — this also works for arrays,
                // whose keys are their indices.
                assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
                inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
            }
            else {
                inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
            }
            const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
            return inputTensorsToSave.concat(outputTensorsToSave);
        }
        // TODO(yassogba) throw exception here once all runkernelFunc calls with
        // inputsToSave/outputsToSave are removed
        return null;
    }
    /**
     * Internal method used by public APIs for tensor creation. Makes a new
     * tensor with the provided shape, dtype and values. It always
     * creates a new data id and writes the values to the underlying backend.
     */
    makeTensor(values, shape, dtype, backend) {
        if (values == null) {
            throw new Error('Values passed to engine.makeTensor() are null');
        }
        dtype = dtype || 'float32';
        backend = backend || this.backend;
        let backendVals = values;
        if (dtype === 'string' && isString(values[0])) {
            // Backends store strings as encoded bytes.
            backendVals = values.map(d => encodeString(d));
        }
        const dataId = backend.write(backendVals, shape, dtype);
        const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
        this.incRef(t, backend);
        // Count bytes for string tensors.
        if (dtype === 'string') {
            const info = this.state.tensorInfo.get(dataId);
            const newBytes = bytesFromStringArray(backendVals);
            // incRef recorded 0 bytes for string dtypes; adjust with the
            // actual encoded size.
            this.state.numBytes += newBytes - info.bytes;
            info.bytes = newBytes;
        }
        return t;
    }
/** | |
* Internal method used by backends. Makes a new tensor | |
* that is a wrapper around an existing data id. It doesn't create | |
* a new data id, only increments the ref count used in memory tracking. | |
*/ | |
makeTensorFromDataId(dataId, shape, dtype, backend) { | |
dtype = dtype || 'float32'; | |
const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); | |
this.incRef(t, backend); | |
return t; | |
} | |
makeVariable(initialValue, trainable = true, name, dtype) { | |
name = name || this.nextVariableId().toString(); | |
if (dtype != null && dtype !== initialValue.dtype) { | |
initialValue = initialValue.cast(dtype); | |
} | |
const v = new Variable(initialValue, trainable, name, this.nextTensorId()); | |
if (this.state.registeredVariables[v.name] != null) { | |
throw new Error(`Variable with name ${v.name} was already registered`); | |
} | |
this.state.registeredVariables[v.name] = v; | |
this.incRef(v, this.backend); | |
return v; | |
} | |
    /**
     * Increments the ref count for the tensor's data and updates the engine's
     * memory bookkeeping. On the first reference, registers the data id with
     * `backend` (or the current backend) and counts its bytes.
     */
    incRef(a, backend) {
        const refCount = this.state.tensorInfo.has(a.dataId) ?
            this.state.tensorInfo.get(a.dataId).refCount :
            0;
        this.state.numTensors++;
        if (a.dtype === 'string') {
            this.state.numStringTensors++;
        }
        if (refCount === 0) {
            this.state.numDataBuffers++;
            // Bytes for complex numbers are counted by their components. Bytes for
            // string tensors are counted when writing values.
            let bytes = 0;
            if (a.dtype !== 'complex64' && a.dtype !== 'string') {
                bytes = a.size * bytesPerElement(a.dtype);
            }
            this.state.tensorInfo.set(a.dataId, {
                backend: backend || this.backend,
                dtype: a.dtype,
                shape: a.shape,
                bytes,
                refCount: 0
            });
            this.state.numBytes += bytes;
        }
        this.state.tensorInfo.get(a.dataId).refCount++;
        // Variables are not tracked by tidy scopes; only plain tensors are.
        if (!(a instanceof Variable)) {
            this.track(a);
        }
    }
    /**
     * Decrements bookkeeping for the tensor and releases its backing data
     * when the last reference is dropped. No-op if the data id is unknown.
     */
    disposeTensor(a) {
        if (!this.state.tensorInfo.has(a.dataId)) {
            return;
        }
        this.state.numTensors--;
        if (a.dtype === 'string') {
            this.state.numStringTensors--;
        }
        const info = this.state.tensorInfo.get(a.dataId);
        const refCount = info.refCount;
        if (refCount <= 1) {
            // Last reference: release the backend data and drop the record.
            // Don't count bytes for complex numbers as they are counted by their
            // components.
            if (a.dtype !== 'complex64') {
                this.state.numBytes -= info.bytes;
            }
            this.state.numDataBuffers--;
            info.backend.disposeData(a.dataId);
            this.state.tensorInfo.delete(a.dataId);
        }
        else {
            this.state.tensorInfo.get(a.dataId).refCount--;
        }
        // TODO(nsthorat): Construct an error and save the stack trace for
        // debugging when in debug mode. Creating a stack trace is too expensive
        // to do unconditionally.
    }
disposeVariables() { | |
for (const varName in this.state.registeredVariables) { | |
const v = this.state.registeredVariables[varName]; | |
this.disposeVariable(v); | |
} | |
} | |
disposeVariable(v) { | |
this.disposeTensor(v); | |
if (this.state.registeredVariables[v.name] != null) { | |
delete this.state.registeredVariables[v.name]; | |
} | |
} | |
memory() { | |
const info = this.backend.memory(); | |
info.numTensors = this.state.numTensors; | |
info.numDataBuffers = this.state.numDataBuffers; | |
info.numBytes = this.state.numBytes; | |
if (this.state.numStringTensors > 0) { | |
info.unreliable = true; | |
if (info.reasons == null) { | |
info.reasons = []; | |
} | |
info.reasons.push('Memory usage by string tensors is approximate ' + | |
'(2 |