Created
June 4, 2018 02:49
-
-
Save mwdchang/2defb6945c17c1836613a8a9d4fb4052 to your computer and use it in GitHub Desktop.
TensorFlow.js core test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<script src="tf.min.js"></script> | |
</head> | |
<body> | |
</body> | |
<script> | |
// Plain stochastic gradient descent shared by every training step.
const LEARNING_RATE = 0.07;
const optimizer = tf.train.sgd(LEARNING_RATE);

// Trainable parameters of a 2 -> 2 -> 5 -> 2 fully connected network: one
// weight matrix and one bias vector per layer. Each is a tf.variable seeded
// from tf.randomNormal (library-default mean/stddev), created in layer order
// (layer-1 weights first, layer-3 bias last).
const [L1w, L1b, L2w, L2b, L3w, L3b] = [
  [2, 2], [2], // layer 1: 2 inputs -> 2 units
  [2, 5], [5], // layer 2: 2 units  -> 5 units
  [5, 2], [2], // layer 3: 5 units  -> 2 outputs
].map((shape) => tf.variable(tf.randomNormal(shape)));
/**
 * Forward pass of the 2 -> 2 -> 5 -> 2 fully connected network.
 *
 * The two hidden layers use tanh activations. The output layer is linear:
 * its values are the logits that `loss` feeds to softmax cross-entropy.
 * Passing them through tanh would confine every logit to [-1, 1]. With two
 * classes, softmax could then never assign more than sigmoid(2) ≈ 0.88
 * probability to a class, and the loss could not drop below ~0.127.
 * Because tanh is monotonic, argMax over the outputs (used for prediction)
 * picks the same class either way.
 *
 * @param {tf.Tensor2D} xs - input batch with 2 columns (L1w is [2, 2]).
 * @returns {tf.Tensor2D} unnormalized class scores (logits), [batch, 2].
 */
const model = (xs) => {
  const l1 = tf.tanh(xs.matMul(L1w).add(L1b));
  const l2 = tf.tanh(l1.matMul(L2w).add(L2b));
  // No activation here: the loss expects raw logits.
  return l2.matMul(L3w).add(L3b);
};
/**
 * Training objective: softmax cross-entropy between the one-hot targets and
 * the network output, reduced with `.mean()`.
 *
 * The targets are cast to float32 before the loss is computed; presumably
 * the tf.oneHot tensors built by createTrainBatch are int32 — confirm.
 * (Mean squared error was tried as an alternative objective.)
 *
 * @param {tf.Tensor} labels - one-hot targets (see createTrainBatch).
 * @param {tf.Tensor} ys - model output, treated as logits.
 * @returns {tf.Tensor} the reduced loss value.
 */
const loss = (labels, ys) => {
  const floatLabels = labels.asType('float32');
  const crossEntropy = tf.losses.softmaxCrossEntropy(floatLabels, ys);
  return crossEntropy.mean();
};
// Synthetic comparison task: each example is a pair of integers (a, b), each
// drawn from 1..maxValue. The label is 0 when a > b and 1 otherwise (a <= b).

/**
 * Draws `num` random (a, b) pairs and their comparison labels as plain arrays.
 *
 * @param {number} num - number of examples to draw.
 * @param {number} [maxValue=10] - inclusive upper bound of each integer; the lower bound is 1.
 * @param {() => number} [random=Math.random] - source of values in [0, 1); injectable for tests.
 * @returns {{pairs: number[], labels: number[]}} `pairs` is flat [a0, b0, a1, b1, ...];
 *   `labels[i]` is 0 if a_i > b_i, else 1.
 */
function samplePairs(num, maxValue = 10, random = Math.random) {
  const pairs = [];
  const labels = [];
  for (let i = 0; i < num; i++) {
    const a = 1 + Math.floor(random() * maxValue);
    const b = 1 + Math.floor(random() * maxValue);
    pairs.push(a, b);
    labels.push(a > b ? 0 : 1);
  }
  return { pairs, labels };
}

/**
 * Builds one batch of the comparison task as tensors.
 *
 * @param {number} num - batch size.
 * @param {number} [maxValue=10] - inclusive upper bound of the sampled integers.
 * @returns {[tf.Tensor2D, tf.Tensor2D, number[]]} [inputs of shape [num, 2],
 *   one-hot labels with depth 2, the raw 0/1 labels as a plain array].
 */
const createTrainBatch = (num, maxValue = 10) => {
  const { pairs, labels } = samplePairs(num, maxValue);
  const inputs = tf.tensor2d(pairs, [num, 2]);
  const oneHotLabels = tf.oneHot(tf.tensor1d(labels, 'int32'), 2);
  return [inputs, oneHotLabels, labels];
};
// Training with SGD/backprop
/**
 * Runs `numSteps` optimizer updates, each on a freshly sampled batch of
 * `batchSize` examples, yielding to the browser between steps.
 *
 * The cost is not requested from optimizer.minimize (returnCost = false).
 * A requested cost scalar must be disposed by the caller, and this loop
 * never reads it, so asking for it would leak one tensor per step.
 *
 * @param {number} [numSteps=50] - number of optimizer steps.
 * @param {number} [batchSize=40] - examples per step.
 * @returns {Promise<void>}
 */
async function train(numSteps = 50, batchSize = 40) {
  // Debug print of the layer-2 weights before training starts.
  L2w.print();
  for (let step = 0; step < numSteps; step++) {
    optimizer.minimize(() => {
      const [input, label] = createTrainBatch(batchSize);
      return loss(label, model(input));
    }, false);
    // Yield a frame so the page stays responsive during training.
    await tf.nextFrame();
  }
}
// Predict
/**
 * Trains the network, then classifies a fresh batch of 10 pairs. Logs the
 * true labels ('A') next to the predicted class indices ('P').
 *
 * Tensors created here are released once their values have been read.
 * tf.tidy frees the forward-pass intermediates; dispose() frees the batch
 * tensors and the predictions.
 */
async function run() {
  await train();
  const axis = 1;
  const [input, label, origLabel] = createTrainBatch(10);
  console.log('A', origLabel);
  const predicted = tf.tidy(() => model(input).argMax(axis));
  console.log('P', Array.from(predicted.dataSync()));
  input.dispose();
  label.dispose();
  predicted.dispose();
}
// Entry point: report any failure during training/prediction explicitly
// instead of leaving the returned promise unhandled.
run().catch((err) => console.error(err));
</script> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment