Last active
May 17, 2019 09:01
-
-
Save ChunML/a14c53ff9c8e5033ae094c1d10fd5d9d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import * as tf from '@tensorflow/tfjs'; | |
import React from 'react'; | |
import ReactDOM from 'react-dom'; | |
/**
 * Bidirectional mapping between a fixed alphabet of unique characters and
 * one-hot encoded tensors.
 */
class CharacterTable {
  /**
   * @param {string} chars - the alphabet; every character must be unique.
   * @throws {Error} if a character appears more than once.
   */
  constructor(chars) {
    // chars must be a list of unique characters
    this.chars = chars;
    this.charIndices = {};
    this.indicesChar = {};
    this.size = this.chars.length;
    // Create conversion dicts (char -> index and index -> char).
    for (let i = 0; i < this.size; i++) {
      const char = this.chars[i];
      // Use `in`, not truthiness: index 0 is falsy, so the old
      // `if (this.charIndices[char])` never caught a duplicate of chars[0].
      if (char in this.charIndices) {
        throw new Error(`Duplicate character ${char}`);
      }
      this.charIndices[char] = i;
      this.indicesChar[i] = char;
    }
  }
  /**
   * One-hot encodes `str` into a [numRows, vocabSize] tensor. Row i is the
   * one-hot vector of str[i]; rows beyond str.length remain all-zero padding.
   * @throws {Error} if `str` contains a character outside the alphabet.
   */
  encode(str, numRows) {
    // create an empty tf.TensorBuffer
    // can be converted to Tensor by .toTensor()
    const buf = tf.buffer([numRows, this.size]);
    for (let i = 0; i < str.length; i++) {
      const char = str[i];
      // Missing keys are `undefined`, not `null`; the old `=== null` check
      // never fired, letting unknown characters pass through silently.
      if (this.charIndices[char] === undefined) {
        throw new Error(`Unknown character ${char}`);
      }
      buf.set(1, i, this.charIndices[char]);
    }
    return buf.toTensor().as2D(numRows, this.size);
  }
  /**
   * One-hot encodes a batch of strings into a
   * [numExamples, numRows, vocabSize] tensor (see `encode` for row layout).
   * @throws {Error} if any string contains an out-of-alphabet character.
   */
  encodeBatch(strings, numRows) {
    const numExamples = strings.length;
    const buf = tf.buffer([numExamples, numRows, this.size]);
    for (let i = 0; i < numExamples; i++) {
      const str = strings[i];
      for (let j = 0; j < str.length; j++) {
        const char = str[j];
        // Same fix as in encode(): check for `undefined`, not `null`.
        if (this.charIndices[char] === undefined) {
          throw new Error(`Unknown character ${char}`);
        }
        buf.set(1, i, j, this.charIndices[char]);
      }
    }
    return buf.toTensor().as3D(numExamples, numRows, this.size);
  }
  /**
   * Decodes a tensor of per-timestep scores (or indices, when
   * `computeArgmax` is false) back into a string.
   */
  decode(x, computeArgmax = true) {
    return tf.tidy(() => {
      if (computeArgmax) {
        // Pick the highest-scoring character per timestep.
        x = x.argMax(1);
      }
      // get values from Tensor
      const xData = x.dataSync();
      const outputWord = Array.from(xData)
          .reduce((output, index) => output + this.indicesChar[index], '');
      return outputWord;
    });
  }
}
/**
 * Generates `numExamples` unique addition problems for operands of up to
 * `digits` digits.
 * @param {number} digits - maximum digit count per operand.
 * @param {number} numExamples - number of unique question/answer pairs.
 * @param {boolean} invert - unsupported; throws if true.
 * @returns {Array<[string, string]>} pairs of [question, answer], where the
 *   question is right-padded to 2*digits+1 chars and the answer to digits+1.
 * @throws {Error} if `invert` is true (not implemented).
 */
function generateData(digits, numExamples, invert) {
  const digitArray = ['0', '1', '2',
                      '3', '4', '5',
                      '6', '7', '8', '9'];
  const arraySize = digitArray.length;
  const output = [];
  // maximum length of a question (2 numbers and 1 operator)
  const maxLen = 2 * digits + 1;
  // Generates one random non-negative integer of up to `digits` digits.
  const f = () => {
    let str = '';
    while (str.length < digits) {
      const index = Math.floor(Math.random() * arraySize);
      str += digitArray[index];
    }
    // Always pass the radix to parseInt.
    return Number.parseInt(str, 10);
  };
  // Track seen operand pairs (order-independent) so every example is unique.
  const seen = new Set();
  while (output.length < numExamples) {
    const a = f();
    const b = f();
    const sorted = b > a ? [a, b] : [b, a];
    const key = sorted[0] + '`' + sorted[1];
    if (seen.has(key)) {
      continue;
    }
    seen.add(key);
    const q = `${a}+${b}`;
    // Right-pad question and answer to fixed widths for the encoder.
    const query = q + ' '.repeat(maxLen - q.length);
    let ans = (a + b).toString();
    ans += ' '.repeat(digits + 1 - ans.length);
    if (invert) {
      throw new Error('Not implemented yet!');
    }
    output.push([query, ans]);
  }
  return output;
}
/**
 * Converts [question, answer] string pairs into a pair of one-hot tensors:
 * [questionTensor (rows = 2*digits+1), answerTensor (rows = digits+1)].
 */
function convertDataToTensors(data, charTable, digits) {
  // Split the pairs into two parallel string lists.
  const questions = [];
  const answers = [];
  for (const [question, answer] of data) {
    questions.push(question);
    answers.push(answer);
  }
  const questionRows = 2 * digits + 1;
  const answerRows = digits + 1;
  return [
    charTable.encodeBatch(questions, questionRows),
    charTable.encodeBatch(answers, answerRows)
  ];
}
/**
 * Builds and compiles a seq2seq-style model: RNN encoder -> repeatVector ->
 * RNN decoder -> per-timestep dense softmax over the vocabulary.
 * @param {number} layers - NOTE(review): accepted but unused — the model
 *   always has exactly two recurrent layers; kept for interface compat.
 * @param {number} hiddenSize - units in each recurrent layer.
 * @param {string} rnnType - 'SimpleRNN' | 'GRU' | 'LSTM'.
 * @param {number} digits - max digits per operand (fixes input/output length).
 * @param {number} vocabSize - size of the character vocabulary.
 * @throws {Error} on an unsupported rnnType.
 */
function createModel(layers, hiddenSize, rnnType, digits, vocabSize) {
  const maxLen = 2 * digits + 1;
  const model = tf.sequential();
  // One switch instead of two duplicated ones: adds a recurrent layer of the
  // requested type, merging the shared config with per-layer extras
  // (inputShape for the encoder, returnSequences for the decoder).
  const addRecurrentLayer = (extraConfig) => {
    const config = {
      units: hiddenSize,
      recurrentInitializer: 'glorotNormal',
      ...extraConfig
    };
    switch (rnnType) {
      case 'SimpleRNN':
        model.add(tf.layers.simpleRNN(config));
        break;
      case 'GRU':
        model.add(tf.layers.gru(config));
        break;
      case 'LSTM':
        model.add(tf.layers.lstm(config));
        break;
      default:
        throw new Error(`Unsupported RNN type ${rnnType}`);
    }
  };
  // Encoder: consumes the one-hot question, emits a single state vector.
  addRecurrentLayer({inputShape: [maxLen, vocabSize]});
  // Repeat the encoded vector once per answer character (digits + 1 steps).
  model.add(tf.layers.repeatVector({n: digits + 1}));
  // Decoder: unrolls the repeated vector into an output sequence.
  addRecurrentLayer({returnSequences: true});
  // Per-timestep classification over the vocabulary.
  model.add(tf.layers.timeDistributed({
    layer: tf.layers.dense({units: vocabSize})
  }));
  model.add(tf.layers.activation({
    activation: 'softmax'
  }));
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: 'adam',
    metrics: ['accuracy']
  });
  return model;
}
/**
 * End-to-end addition "translator": generates data, builds the model, and
 * runs the train/evaluate loop.
 */
class AdditionRNN {
  /**
   * @param {number} digits - maximum digits per operand.
   * @param {number} trainingSize - total examples to generate (90/10 split).
   * @param {string} rnnType - 'SimpleRNN' | 'GRU' | 'LSTM'.
   * @param {number} layers - forwarded to createModel.
   * @param {number} hiddenSize - recurrent units per layer.
   */
  constructor(digits, trainingSize, rnnType, layers, hiddenSize) {
    // Vocabulary: ten digits, the plus sign, and the space used as padding.
    const chars = '0123456789+ ';
    this.charTable = new CharacterTable(chars);
    console.log('Generating training data...');
    const data = generateData(digits, trainingSize, false);
    // 90% train / 10% held-out test.
    const split = Math.floor(trainingSize * 0.9);
    this.trainData = data.slice(0, split);
    this.testData = data.slice(split);
    [this.trainXs, this.trainYs] = convertDataToTensors(
        this.trainData, this.charTable, digits);
    [this.testXs, this.testYs] = convertDataToTensors(
        this.testData, this.charTable, digits);
    this.model = createModel(
        layers, hiddenSize, rnnType, digits, chars.length);
  }
  /**
   * Trains for `iterations` single-epoch fits; after each, decodes
   * `numTestExamples` test predictions and reports stats via `callback`.
   * @param {function} callback - (trainLoss, trainAcc, examples, isCorrect).
   */
  async train(iterations, batchSize, numTestExamples, callback) {
    const lossValues = [[], []];
    const accuracyValues = [[], []];
    for (let i = 0; i < iterations; i++) {
      const beginMs = performance.now();
      const history = await this.model.fit(this.trainXs, this.trainYs, {
        epochs: 1,
        batchSize,
        validationData: [this.testXs, this.testYs],
        yieldEvery: 'epoch'
      });
      // Bug fix: elapsed time is now - begin; the original stored the raw
      // timestamp, making modelFitTime the absolute clock reading.
      const elapsedMs = performance.now() - beginMs;
      const modelFitTime = elapsedMs / 1000;
      const trainLoss = history.history['loss'][0];
      const trainAcc = history.history['acc'][0];
      const valLoss = history.history['val_loss'][0];
      const valAcc = history.history['val_acc'][0];
      lossValues[0].push({'x': i, 'y': trainLoss});
      lossValues[1].push({'x': i, 'y': valLoss});
      accuracyValues[0].push({'x': i, 'y': trainAcc});
      accuracyValues[1].push({'x': i, 'y': valAcc});
      const examples = [];
      const isCorrect = [];
      // Run the whole decode step inside tidy so the sliced input AND the
      // prediction tensors are disposed — the slice used to be created
      // outside tidy and leaked GPU/CPU memory on every iteration.
      tf.tidy(() => {
        const testXsForDisplay = this.testXs.slice(
            [0, 0, 0],
            [numTestExamples, this.testXs.shape[1], this.testXs.shape[2]]);
        const predictOut = this.model.predict(testXsForDisplay);
        for (let k = 0; k < numTestExamples; k++) {
          // Per-example [timesteps, vocab] score matrix.
          const scores = predictOut.slice(
              [k, 0, 0],
              [1, predictOut.shape[1], predictOut.shape[2]])
              .as2D(predictOut.shape[1], predictOut.shape[2]);
          const decoded = this.charTable.decode(scores);
          examples.push(this.testData[k][0] + ' = ' + decoded);
          // Compare ignoring the space padding on both sides.
          isCorrect.push(this.testData[k][1].trim() === decoded.trim());
        }
      });
      console.log(`Iteration ${i}: fit took ${modelFitTime.toFixed(2)}s`);
      callback(trainLoss, trainAcc, examples, isCorrect);
      console.log(examples);
    }
  }
}
/**
 * Configures and launches the addition-RNN demo, forwarding per-iteration
 * training stats to `callback`.
 */
async function runAdditionRNN(callback) {
  // Hyper-parameters for the demo run.
  const digits = 4;
  const trainingSize = 10000;
  const rnnType = 'LSTM';
  const layers = 2;
  const hiddenSize = 128;
  const batchSize = 32;
  const trainIterations = 10000;
  const numTestExamples = 50;
  // There are only 10^digits choices per operand, so 10^(2*digits)
  // ordered questions in total.
  const trainingSizeLimit = Math.pow(10, 2 * digits);
  if (trainingSize > trainingSizeLimit) {
    console.log('Training size is too large');
    return;
  }
  const rnn = new AdditionRNN(
      digits, trainingSize, rnnType, layers, hiddenSize);
  await rnn.train(trainIterations, batchSize, numTestExamples, callback);
}
// runAdditionRNN(); | |
class App extends React.Component { | |
constructor(props) { | |
super(props); | |
this.state = { | |
loss: '', | |
accuracy: '', | |
examples: [], | |
isCorrect: [] | |
}; | |
} | |
componentDidMount() { | |
runAdditionRNN((loss, accuracy, examples, isCorrect) => { | |
this.setState({ | |
loss, | |
accuracy, | |
examples, | |
isCorrect | |
}); | |
}); | |
} | |
render() { | |
const { loss, accuracy, examples, isCorrect } = this.state; | |
const examplesLi = examples.map((ex, i) => ( | |
<li key={ i }> | |
{ ex } <span style={{color: isCorrect[i] ? 'rgb(20, 230, 30)' : 'rgb(250, 20, 20)'}}>{isCorrect[i] ? '✔' : '✗'}</span> | |
</li> | |
)) | |
return ( | |
<React.Fragment> | |
<div> | |
<p>Loss value:</p> | |
<p>{ loss }</p> | |
<p>Accuracy value:</p> | |
<p>{ accuracy }</p> | |
<p>Test examples:</p> | |
<ul> | |
{ examplesLi } | |
</ul> | |
</div> | |
</React.Fragment> | |
); | |
} | |
} | |
// Mount the demo app into the page's #app container.
ReactDOM.render(
  <App />,
  document.getElementById('app')
);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment