Skip to content

Instantly share code, notes, and snippets.

@ChunML
Last active May 17, 2019 09:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ChunML/a14c53ff9c8e5033ae094c1d10fd5d9d to your computer and use it in GitHub Desktop.
Save ChunML/a14c53ff9c8e5033ae094c1d10fd5d9d to your computer and use it in GitHub Desktop.
import * as tf from '@tensorflow/tfjs';
import React from 'react';
import ReactDOM from 'react-dom';
/**
 * Bidirectional character <-> index vocabulary, with helpers that one-hot
 * encode strings into TensorFlow.js tensors and decode predictions back.
 */
class CharacterTable {
  /**
   * @param {string|string[]} chars - Vocabulary; each entry must be unique.
   *   Entry `i` is assigned index `i`.
   * @throws {Error} If `chars` contains a duplicate entry.
   */
  constructor(chars) {
    this.chars = chars;
    this.charIndices = {};
    this.indicesChar = {};
    this.size = this.chars.length;
    for (let i = 0; i < this.size; i++) {
      const char = this.chars[i];
      // Check key presence, not truthiness: the first character maps to
      // index 0 (falsy), so a truthiness test misses duplicates of it.
      if (Object.prototype.hasOwnProperty.call(this.charIndices, char)) {
        throw new Error(`Duplicate character ${char}`);
      }
      this.charIndices[char] = i;
      this.indicesChar[i] = char;
    }
  }

  /**
   * Look up the index of a single vocabulary entry.
   * @param {string} char
   * @returns {number} The index assigned to `char`.
   * @throws {Error} If `char` is not in the vocabulary.
   */
  indexOfChar(char) {
    // A missing key reads as `undefined` (never `null`), so test presence
    // directly instead of comparing the looked-up value.
    if (!Object.prototype.hasOwnProperty.call(this.charIndices, char)) {
      throw new Error(`Unknown character ${char}`);
    }
    return this.charIndices[char];
  }

  /**
   * Map every character of `str` to its index, after checking that `str`
   * fits in `numRows` rows.
   * @param {string} str
   * @param {number} numRows
   * @returns {number[]} One index per character of `str`.
   * @throws {Error} If `str` is longer than `numRows` or has an unknown character.
   */
  stringToIndices(str, numRows) {
    // An over-long string would address rows beyond `numRows` (in a batch,
    // the next example's slot), so reject it explicitly.
    if (str.length > numRows) {
      throw new Error(
        `String of length ${str.length} does not fit in ${numRows} rows`);
    }
    const indices = [];
    for (let i = 0; i < str.length; i++) {
      indices.push(this.indexOfChar(str[i]));
    }
    return indices;
  }

  /**
   * One-hot encode a string as a [numRows, size] tensor; row `i` has a 1 at
   * the index of `str[i]`, rows past the end of `str` are not written.
   * @param {string} str
   * @param {number} numRows
   * @returns {tf.Tensor2D}
   * @throws {Error} On unknown characters or if `str` exceeds `numRows`.
   */
  encode(str, numRows) {
    // Validate before allocating so bad input never produces a partial buffer.
    const indices = this.stringToIndices(str, numRows);
    const buf = tf.buffer([numRows, this.size]);
    indices.forEach((charIndex, row) => buf.set(1, row, charIndex));
    return buf.toTensor().as2D(numRows, this.size);
  }

  /**
   * One-hot encode many strings as a [strings.length, numRows, size] tensor.
   * @param {string[]} strings
   * @param {number} numRows
   * @returns {tf.Tensor3D}
   * @throws {Error} On unknown characters or if any string exceeds `numRows`.
   */
  encodeBatch(strings, numRows) {
    const numExamples = strings.length;
    const allIndices = strings.map((str) => this.stringToIndices(str, numRows));
    const buf = tf.buffer([numExamples, numRows, this.size]);
    allIndices.forEach((indices, example) => {
      indices.forEach((charIndex, row) => buf.set(1, example, row, charIndex));
    });
    return buf.toTensor().as3D(numExamples, numRows, this.size);
  }

  /**
   * Convert per-position scores (or indices) back into a string.
   * @param {tf.Tensor} x - When `computeArgmax` is true, the argmax is taken
   *   over axis 1 (presumably positions x vocabulary — TODO confirm for other callers).
   * @param {boolean} [computeArgmax=true] - When false, `x` already holds indices.
   * @returns {string}
   */
  decode(x, computeArgmax = true) {
    return tf.tidy(() => {
      const indices = computeArgmax ? x.argMax(1) : x;
      return Array.from(indices.dataSync())
        .reduce((output, index) => output + this.indicesChar[index], '');
    });
  }
}
/**
 * Generate unique random addition problems as `[query, answer]` pairs.
 *
 * The query is "a+b" right-padded with spaces to 2 * digits + 1 characters.
 * The answer is the sum right-padded to digits + 1 characters. Only one of
 * a+b / b+a is ever emitted.
 *
 * @param {number} digits - Maximum digits per operand (positive integer).
 * @param {number} numExamples - Number of unique examples to produce.
 * @param {boolean} invert - Must be falsy; inversion is not implemented.
 * @returns {Array<[string, string]>}
 * @throws {Error} If `invert` is truthy, `digits` is not a positive integer,
 *   or more unique examples are requested than exist.
 */
function generateData(digits, numExamples, invert) {
  if (invert) {
    throw new Error('Not implemented yet!');
  }
  if (!Number.isInteger(digits) || digits < 1) {
    throw new Error(`digits must be a positive integer, got ${digits}`);
  }
  const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
  // Longest query: two `digits`-digit operands plus the '+' sign.
  const maxLen = 2 * digits + 1;
  // Longest answer: the sum can carry into one extra digit.
  const answerLen = digits + 1;
  // a+b and b+a are de-duplicated, so only unordered pairs are distinct.
  // Requesting more would make the rejection loop below spin forever.
  const numOperands = Math.pow(10, digits);
  const maxUnique = numOperands * (numOperands + 1) / 2;
  if (numExamples > maxUnique) {
    throw new Error(`Cannot generate ${numExamples} unique examples with ${digits} digit(s); at most ${maxUnique} exist`);
  }
  // Draw `digits` random digits; leading zeros yield shorter numbers.
  const randomOperand = () => {
    let str = '';
    while (str.length < digits) {
      str += digitArray[Math.floor(Math.random() * digitArray.length)];
    }
    return Number.parseInt(str, 10);
  };
  const seen = new Set();
  const output = [];
  while (output.length < numExamples) {
    const a = randomOperand();
    const b = randomOperand();
    const key = a < b ? `${a},${b}` : `${b},${a}`;
    if (seen.has(key)) {
      continue;
    }
    seen.add(key);
    const query = `${a}+${b}`.padEnd(maxLen);
    const answer = String(a + b).padEnd(answerLen);
    output.push([query, answer]);
  }
  return output;
}
/**
 * Split `[query, answer]` pairs into two one-hot batches.
 *
 * @param {Array<[string, string]>} data - Padded question/answer pairs.
 * @param {CharacterTable} charTable - Encoder providing `encodeBatch`.
 * @param {number} digits - Max digits per operand. Questions get
 *   2 * digits + 1 rows and answers get digits + 1 rows.
 * @returns {Array} `[questionBatch, answerBatch]`, questions encoded first.
 */
function convertDataToTensors(data, charTable, digits) {
  const questionRows = 2 * digits + 1;
  const answerRows = digits + 1;
  const questions = [];
  const answers = [];
  data.forEach((pair) => {
    questions.push(pair[0]);
    answers.push(pair[1]);
  });
  const xs = charTable.encodeBatch(questions, questionRows);
  const ys = charTable.encodeBatch(answers, answerRows);
  return [xs, ys];
}
/**
 * Build and compile a sequence-to-sequence RNN for the addition task.
 *
 * The layers are added in this order:
 * - an encoder RNN over the one-hot query;
 * - a RepeatVector that copies its output to each answer position;
 * - `layers` stacked decoder RNNs returning sequences;
 * - a time-distributed Dense projection to the vocabulary;
 * - a softmax.
 *
 * @param {number} layers - Number of stacked decoder RNN layers. Values that
 *   are not positive integers fall back to 1.
 * @param {number} hiddenSize - Units in every RNN layer.
 * @param {string} rnnType - One of 'SimpleRNN', 'GRU' or 'LSTM'.
 * @param {number} digits - Max digits per operand. This fixes the query length
 *   (2 * digits + 1) and the answer length (digits + 1).
 * @param {number} vocabSize - Number of distinct characters (one-hot depth).
 * @returns {tf.Sequential} The compiled model.
 * @throws {Error} If `rnnType` is not supported.
 */
function createModel(layers, hiddenSize, rnnType, digits, vocabSize) {
  // Single lookup replaces two identical switch statements.
  const rnnFactories = {
    SimpleRNN: (config) => tf.layers.simpleRNN(config),
    GRU: (config) => tf.layers.gru(config),
    LSTM: (config) => tf.layers.lstm(config),
  };
  if (!Object.prototype.hasOwnProperty.call(rnnFactories, rnnType)) {
    throw new Error(`Unsupported RNN type ${rnnType}`);
  }
  const makeRnn = rnnFactories[rnnType];
  const maxLen = 2 * digits + 1;
  const decoderDepth = Number.isInteger(layers) && layers > 0 ? layers : 1;

  const model = tf.sequential();
  // Encoder: reads the whole query and emits one vector.
  model.add(makeRnn({
    units: hiddenSize,
    recurrentInitializer: 'glorotNormal',
    inputShape: [maxLen, vocabSize]
  }));
  // Feed that vector to every answer position.
  model.add(tf.layers.repeatVector({n: digits + 1}));
  // Decoder: `layers` stacked RNNs, each returning the full sequence.
  for (let i = 0; i < decoderDepth; i++) {
    model.add(makeRnn({
      units: hiddenSize,
      recurrentInitializer: 'glorotNormal',
      returnSequences: true
    }));
  }
  model.add(tf.layers.timeDistributed({
    layer: tf.layers.dense({units: vocabSize})
  }));
  model.add(tf.layers.activation({
    activation: 'softmax'
  }));
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: 'adam',
    metrics: ['accuracy']
  });
  return model;
}
/**
 * Bundles the vocabulary, the train/test data and the model for the
 * addition task.
 */
class AdditionRNN {
  /**
   * @param {number} digits - Max digits per operand.
   * @param {number} trainingSize - Total examples. The first 90% are used for
   *   training and the remaining 10% for validation.
   * @param {string} rnnType - Passed to createModel.
   * @param {number} layers - Passed to createModel.
   * @param {number} hiddenSize - Passed to createModel.
   */
  constructor(digits, trainingSize, rnnType, layers, hiddenSize) {
    const chars = '0123456789+ ';
    this.charTable = new CharacterTable(chars);
    console.log('Generating training data...');
    const data = generateData(digits, trainingSize, false);
    const split = Math.floor(trainingSize * 0.9);
    this.trainData = data.slice(0, split);
    this.testData = data.slice(split);
    [this.trainXs, this.trainYs] = convertDataToTensors(
      this.trainData, this.charTable, digits);
    [this.testXs, this.testYs] = convertDataToTensors(
      this.testData, this.charTable, digits);
    this.model = createModel(
      layers, hiddenSize, rnnType, digits, chars.length);
  }

  /**
   * Train for `iterations` rounds of one epoch each.
   *
   * After every round the callback is invoked as
   * `callback(trainLoss, trainAcc, examples, isCorrect, fitSeconds)`.
   * `examples` holds the decoded predictions for the first
   * `numTestExamples` test items; it is capped at the test-set size.
   *
   * @param {number} iterations
   * @param {number} batchSize
   * @param {number} numTestExamples
   * @param {Function} callback
   * @returns {Promise<{lossValues: Array, accuracyValues: Array}>} Per-round
   *   `{x, y}` points: index 0 is the training series, index 1 is validation.
   */
  async train(iterations, batchSize, numTestExamples, callback) {
    const lossValues = [[], []];
    const accuracyValues = [[], []];
    // Never slice or index past the end of the test split.
    const numDisplay = Math.max(0, Math.min(numTestExamples, this.testData.length));
    for (let i = 0; i < iterations; i++) {
      const beginMs = performance.now();
      const history = await this.model.fit(this.trainXs, this.trainYs, {
        epochs: 1,
        batchSize,
        validationData: [this.testXs, this.testYs],
        yieldEvery: 'epoch'
      });
      // Elapsed time for this round, not the absolute timestamp.
      const modelFitTime = (performance.now() - beginMs) / 1000;
      const trainLoss = history.history['loss'][0];
      const trainAcc = history.history['acc'][0];
      const valLoss = history.history['val_loss'][0];
      const valAcc = history.history['val_acc'][0];
      lossValues[0].push({'x': i, 'y': trainLoss});
      lossValues[1].push({'x': i, 'y': valLoss});
      accuracyValues[0].push({'x': i, 'y': trainAcc});
      accuracyValues[1].push({'x': i, 'y': valAcc});
      const examples = [];
      const isCorrect = [];
      if (numDisplay > 0) {
        tf.tidy(() => {
          // Slice inside tidy so the display tensor is released every round.
          const testXsForDisplay = this.testXs.slice(
            [0, 0, 0],
            [numDisplay, this.testXs.shape[1], this.testXs.shape[2]]);
          const predictOut = this.model.predict(testXsForDisplay);
          for (let k = 0; k < numDisplay; k++) {
            const scores = predictOut.slice(
              [k, 0, 0],
              [1, predictOut.shape[1], predictOut.shape[2]])
              .as2D(predictOut.shape[1], predictOut.shape[2]);
            const decoded = this.charTable.decode(scores);
            examples.push(this.testData[k][0] + ' = ' + decoded);
            isCorrect.push(this.testData[k][1].trim() === decoded.trim());
          }
        });
      }
      callback(trainLoss, trainAcc, examples, isCorrect, modelFitTime);
      console.log(examples);
    }
    return {lossValues, accuracyValues};
  }
}
/**
 * Build an AdditionRNN and train it, reporting progress via `callback`.
 *
 * @param {Function} callback - Passed to AdditionRNN#train.
 * @param {object} [options] - Hyper-parameter overrides. The defaults equal
 *   the previously hard-coded values.
 * @returns {Promise<void>} Resolves when training finishes. If `trainingSize`
 *   exceeds the number of distinct problems, it logs and returns immediately.
 */
async function runAdditionRNN(callback, {
  digits = 4,
  trainingSize = 10000,
  rnnType = 'LSTM',
  layers = 2,
  hiddenSize = 128,
  batchSize = 32,
  trainIterations = 10000,
  numTestExamples = 50
} = {}) {
  // generateData treats a+b and b+a as the same problem, so only unordered
  // operand pairs are distinct: n * (n + 1) / 2 with n = 10^digits.
  // Asking for more than that would never terminate.
  const numOperands = Math.pow(10, digits);
  const trainingSizeLimit = numOperands * (numOperands + 1) / 2;
  if (trainingSize > trainingSizeLimit) {
    console.log('Training size is too large');
    return;
  }
  const rnn = new AdditionRNN(
    digits, trainingSize, rnnType, layers, hiddenSize);
  await rnn.train(trainIterations, batchSize, numTestExamples, callback);
}
// runAdditionRNN();
class App extends React.Component {
constructor(props) {
super(props);
this.state = {
loss: '',
accuracy: '',
examples: [],
isCorrect: []
};
}
componentDidMount() {
runAdditionRNN((loss, accuracy, examples, isCorrect) => {
this.setState({
loss,
accuracy,
examples,
isCorrect
});
});
}
render() {
const { loss, accuracy, examples, isCorrect } = this.state;
const examplesLi = examples.map((ex, i) => (
<li key={ i }>
{ ex } <span style={{color: isCorrect[i] ? 'rgb(20, 230, 30)' : 'rgb(250, 20, 20)'}}>{isCorrect[i] ? '✔' : '✗'}</span>
</li>
))
return (
<React.Fragment>
<div>
<p>Loss value:</p>
<p>{ loss }</p>
<p>Accuracy value:</p>
<p>{ accuracy }</p>
<p>Test examples:</p>
<ul>
{ examplesLi }
</ul>
</div>
</React.Fragment>
);
}
}
ReactDOM.render(
<App />,
document.getElementById('app')
);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment