Skip to content

Instantly share code, notes, and snippets.

@mikeemoo
Last active June 22, 2018 20:13
Show Gist options
  • Save mikeemoo/b4affe711d56bd9828e8db7ce73ae13c to your computer and use it in GitHub Desktop.
Save mikeemoo/b4affe711d56bd9828e8db7ce73ae13c to your computer and use it in GitHub Desktop.
fast-style-transfer-tensorflow-js
import * as tf from "@tensorflow/tfjs-core";
// Default base URL of the checkpoint directories: a `ckpts/` folder next to the
// current page. Resolving through the URL constructor (instead of cutting the
// string at its last "/") ignores the query string and fragment, which may
// themselves contain "/" characters and previously produced a bogus base.
const CKPTSDIR = new URL("ckpts/", document.URL).href;
// Fetches `url` and rejects with a descriptive Error on a non-OK HTTP status,
// so a missing file is reported as such instead of failing later while parsing.
const fetchChecked = async url => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(
      `Failed to fetch ${url}: ${response.status} ${response.statusText}`
    );
  }
  return response;
};

/**
 * Downloads the checkpoint for style `id` and materializes every variable as a tensor.
 *
 * Reads `${ckptsDir}${id}/manifest.json`, whose keys are variable names and whose
 * values carry a `shape`, then fetches `${ckptsDir}${id}/${name}` for each key and
 * interprets the raw bytes as float32 values.
 *
 * @param {string} id - Name of the style directory under `ckptsDir`.
 * @param {string} [ckptsDir=CKPTSDIR] - Base URL of the checkpoint directories
 *   (concatenated directly, so it must end with "/").
 * @returns {Promise<Object<string, tf.Tensor>>} Variable name -> tensor. Release the
 *   tensors with `disposeStyle` when done.
 * @throws {Error} (as a rejection) when any request returns a non-OK status, or when
 *   a tensor cannot be built from the downloaded data.
 */
export const loadStyle = async (id, ckptsDir = CKPTSDIR) => {
  const baseUrl = `${ckptsDir}${id}/`;
  const manifest = await (await fetchChecked(`${baseUrl}manifest.json`)).json();
  const variableNames = Object.keys(manifest);

  // Download every buffer before allocating any tensor, so a failed request
  // cannot leave already-created tensors behind.
  const buffers = await Promise.all(
    variableNames.map(async name =>
      (await fetchChecked(`${baseUrl}${name}`)).arrayBuffer()
    )
  );

  const style = {};
  try {
    variableNames.forEach((name, index) => {
      style[name] = tf.Tensor.make(manifest[name].shape, {
        values: new Float32Array(buffers[index])
      });
    });
  } catch (err) {
    // Free whatever was allocated before the failure, then propagate it.
    Object.values(style).forEach(tensor => tensor.dispose());
    throw err;
  }
  return style;
};
/** Debug helper: prints the report returned by `tf.memory()` to the console. */
export const memory = () => {
  const usage = tf.memory();
  console.log(usage);
};
/**
 * Calls `dispose()` on every value of a style object, such as one from `loadStyle`.
 * @param {Object<string, {dispose: Function}>} style
 */
export const disposeStyle = style => {
  for (const tensor of Object.values(style)) {
    tensor.dispose();
  }
};
// Left-to-right function composition: pipe(f, g, h)(x) === h(g(f(x))).
// With no functions it returns its argument unchanged.
const pipe = (...ops) => x => {
  let acc = x;
  for (const op of ops) {
    acc = op(acc);
  }
  return acc;
};
/**
 * Runs the style-transfer network on `image` using the weights in `style`.
 *
 * All intermediate tensors are created inside a single `tf.tidy`, so only the
 * returned tensor survives.
 *
 * @param image - Anything `tf.fromPixels` accepts (presumably an image, canvas or
 *   ImageData element — per the tfjs docs).
 * @param {Object<string, tf.Tensor>} style - Variables as returned by `loadStyle`.
 * @returns {tf.Tensor} The stylized image, with values clipped to [0, 255].
 */
export const predict = (image, style) =>
  tf.tidy(() => {
    const network = pipe(
      pixels => tf.fromPixels(pixels).toFloat(),
      // Three convolutions; the second and third use stride 2.
      convLayer(style, 1, true, 0),
      convLayer(style, 2, true, 3),
      convLayer(style, 2, true, 6),
      // Five residual blocks, each reading six consecutive variables.
      ...[9, 15, 21, 27, 33].map(varId => residualBlock(style, varId)),
      // Two stride-2 transposed convolutions, then a final convolution without ReLU.
      convTransposeLayer(style, 64, 2, 39),
      convTransposeLayer(style, 32, 2, 42),
      convLayer(style, 1, false, 45),
      // tanh gives (-1, 1); scaling by 150 and shifting by 127.5 gives
      // (-22.5, 277.5), which is then clamped to the pixel range [0, 255].
      activations => {
        const squashed = tf.tanh(activations);
        const scaled = tf.mul(tf.scalar(150), squashed);
        const shifted = tf.add(tf.scalar(255 / 2), scaled);
        return tf.clipByValue(shifted, 0, 255);
      }
    );
    return network(image);
  });
// Maps a numeric variable id to its checkpoint key: 0 (or less) -> "Variable",
// n > 0 -> "Variable_n".
const varName = varId => (varId > 0 ? `Variable_${varId}` : "Variable");
/**
 * Builds a layer function: convolution -> instance norm -> optional ReLU.
 *
 * Uses three variables: the filter `varName(varId)`, plus the instance-norm
 * shift/scale at `varId + 1` and `varId + 2`.
 *
 * The stride is passed as tf.conv2d's `strides` argument with "same" padding,
 * matching convTransposeLayer. The earlier call `conv2d(input, filter, 1, strides)`
 * used a fixed stride of 1 and passed the stride as numeric padding.
 * The tidy now wraps the layer's application, not the construction of a closure,
 * so this layer's intermediate tensors are actually released.
 *
 * @param {Object<string, tf.Tensor>} style - Variables as returned by loadStyle.
 * @param {number} strides - Convolution stride (same in both spatial dimensions).
 * @param {boolean} relu - Whether to apply ReLU after normalization.
 * @param {number} varId - Id of the filter variable.
 * @returns {(input: tf.Tensor) => tf.Tensor}
 */
const convLayer = (style, strides, relu, varId) => input =>
  tf.tidy(() => {
    const convolved = tf.conv2d(input, style[varName(varId)], strides, "same");
    const normalized = instanceNorm(style, varId + 1)(convolved);
    return relu ? tf.relu(normalized) : normalized;
  });
/**
 * Builds a layer function: transposed convolution -> instance norm -> ReLU.
 *
 * The output shape passed to tf.conv2dTranspose is
 * [rows * strides, cols * strides, numFilters], where rows/cols are the first
 * two entries of the input's shape. The filter is `varName(varId)`, and the
 * instance-norm variables start at `varId + 1`.
 *
 * @param {Object<string, tf.Tensor>} style - Variables as returned by loadStyle.
 * @param {number} numFilters - Channel count of the output.
 * @param {number} strides - Upsampling stride.
 * @param {number} varId - Id of the filter variable.
 * @returns {(input: tf.Tensor) => tf.Tensor}
 */
const convTransposeLayer = (style, numFilters, strides, varId) => input => {
  const [rows, cols] = input.shape;
  const outputShape = [rows * strides, cols * strides, numFilters];
  return tf.tidy(() => {
    const upsampled = tf.conv2dTranspose(
      input,
      style[varName(varId)],
      outputShape,
      strides,
      "same"
    );
    const normalized = instanceNorm(style, varId + 1)(upsampled);
    return tf.relu(normalized);
  });
};
/**
 * Builds a residual block: two stride-1 conv layers (ReLU on the first only),
 * whose result is added back onto the block's input with tf.addStrict.
 *
 * Reads the conv layers starting at `varId` and `varId + 3`.
 *
 * @param {Object<string, tf.Tensor>} style - Variables as returned by loadStyle.
 * @param {number} varId - Id of the first conv layer's filter variable.
 * @returns {(input: tf.Tensor) => tf.Tensor}
 */
const residualBlock = (style, varId) => input =>
  tf.tidy(() => {
    const firstConv = convLayer(style, 1, true, varId);
    const secondConv = convLayer(style, 1, false, varId + 3);
    const branch = secondConv(firstConv(input));
    return tf.addStrict(branch, input);
  });
/**
 * Builds an instance-normalization function for a 3-D input [height, width, depth].
 *
 * Computes the mean and variance over axes 0 and 1 (one value per channel). It then
 * returns scale * (input - mean) / sqrt(variance + 1e-3) + shift, reshaped back to
 * the input's 3-D shape.
 * `shift` is `varName(varId)` and `scale` is `varName(varId + 1)`.
 *
 * @param {Object<string, tf.Tensor>} style - Variables as returned by loadStyle.
 * @param {number} varId - Id of the shift variable.
 * @returns {(input: tf.Tensor) => tf.Tensor}
 */
const instanceNorm = (style, varId) => input =>
  tf.tidy(() => {
    const [height, width, inDepth] = input.shape;
    const { mean, variance } = tf.moments(input, [0, 1]);
    const offset = style[varName(varId)];
    const gain = style[varName(varId + 1)];
    const centered = tf.sub(input, mean);
    const stdDev = tf.sqrt(tf.add(variance, tf.scalar(1e-3)));
    const normalized = tf.div(centered, stdDev);
    return tf.add(tf.mul(gain, normalized), offset).as3D(height, width, inDepth);
  });
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment