Skip to content

Instantly share code, notes, and snippets.

@mwdchang
Created June 4, 2018 02:49
Show Gist options
  • Save mwdchang/2defb6945c17c1836613a8a9d4fb4052 to your computer and use it in GitHub Desktop.
Save mwdchang/2defb6945c17c1836613a8a9d4fb4052 to your computer and use it in GitHub Desktop.
TensorFlow.js core test
<!DOCTYPE html>
<html lang="en">
<head>
<script src="tf.min.js"></script>
</head>
<body>
</body>
<script>
// Step size for plain stochastic gradient descent.
const LEARNING_RATE = 0.07;
const optimizer = tf.train.sgd(LEARNING_RATE);

// Trainable parameters for a fully-connected 2 -> 2 -> 5 -> 2 network.
// Every weight matrix and bias vector starts as a tf.variable filled from
// tf.randomNormal (called with only a shape, i.e. its default distribution).
const makeTrainable = (shape) => tf.variable(tf.randomNormal(shape));

const L1w = makeTrainable([2, 2]);
const L1b = makeTrainable([2]);
const L2w = makeTrainable([2, 5]);
const L2b = makeTrainable([5]);
const L3w = makeTrainable([5, 2]);
const L3b = makeTrainable([2]);
/**
 * Forward pass of the 2 -> 2 -> 5 -> 2 network built from L1w/L1b, L2w/L2b
 * and L3w/L3b.
 *
 * @param {tf.Tensor} xs - input batch; presumably shape [batch, 2] to match
 *   L1w's [2, 2] — confirm against callers.
 * @returns {tf.Tensor} unnormalized class scores (logits), one column per class.
 */
const model = (xs) => {
  const hidden1 = tf.tanh(xs.matMul(L1w).add(L1b));
  const hidden2 = tf.tanh(hidden1.matMul(L2w).add(L2b));
  // No tanh on the output layer: softmax cross-entropy applies the softmax
  // itself and expects raw logits. Squashing them into [-1, 1] capped the
  // predicted confidence at ~88% and kept the loss from dropping below ~0.13.
  // argMax-based predictions are unaffected because tanh is monotonic.
  return hidden2.matMul(L3w).add(L3b);
};
// Loss function
/**
 * Softmax cross-entropy between the target labels and the model's outputs,
 * reduced to a single scalar with mean().
 *
 * @param {tf.Tensor} labels - one-hot targets; cast to 'float32' before use.
 * @param {tf.Tensor} ys - model outputs for the same batch.
 * @returns {tf.Tensor} scalar loss value for the optimizer.
 */
const loss = (labels, ys) => {
  const floatLabels = labels.asType('float32');
  const crossEntropy = tf.losses.softmaxCrossEntropy(floatLabels, ys);
  return crossEntropy.mean();
};
// Dummy data: each sample is a pair (a, b) of random integers in [1, maxValue],
// labelled 0 when a > b and 1 otherwise (i.e. when a <= b).

/**
 * Draws `num` random integer pairs and their comparison labels as plain arrays.
 *
 * @param {number} num - number of samples to draw.
 * @param {number} [maxValue=10] - inclusive upper bound of each integer
 *   (the lower bound is 1).
 * @returns {{pairs: number[], labels: number[]}} `pairs` is flat
 *   `[a0, b0, a1, b1, ...]`; `labels[i]` is 0 if a_i > b_i, else 1.
 */
const sampleComparisonPairs = (num, maxValue = 10) => {
  const pairs = [];
  const labels = [];
  for (let i = 0; i < num; i++) {
    const a = 1 + Math.floor(Math.random() * maxValue);
    const b = 1 + Math.floor(Math.random() * maxValue);
    pairs.push(a, b);
    labels.push(a > b ? 0 : 1);
  }
  return { pairs, labels };
};

/**
 * Builds a random batch for the "is a > b?" task.
 *
 * @param {number} num - batch size.
 * @param {number} [maxValue=10] - inclusive upper bound of the sampled integers.
 * @returns {Array} `[inputs, oneHotLabels, labels]`: a `[num, 2]` tensor of
 *   (a, b) pairs, a depth-2 one-hot tensor of the labels, and the raw 0/1
 *   labels as a plain array (convenient for printing).
 */
const createTrainBatch = (num, maxValue = 10) => {
  const { pairs, labels } = sampleComparisonPairs(num, maxValue);
  const inputs = tf.tensor2d(pairs, [num, 2]);
  // tidy() frees the intermediate int32 index tensor right away, even when this
  // is called outside any other tidy scope; only the one-hot result survives.
  const oneHotLabels = tf.tidy(() => tf.oneHot(tf.tensor1d(labels, 'int32'), 2));
  return [inputs, oneHotLabels, labels];
};
// Training with SGD/backprop
/**
 * Fits the network by repeatedly minimizing the loss on freshly sampled
 * batches. Each step logs the batch loss and then awaits tf.nextFrame()
 * before continuing (yielding to the browser).
 *
 * @param {number} [numSteps=50] - number of optimizer steps.
 * @param {number} [batchSize=40] - samples drawn for each step.
 * @returns {Promise<void>} resolves once all steps have run.
 */
async function train(numSteps = 50, batchSize = 40) {
  for (let step = 0; step < numSteps; step++) {
    // With returnCost=true, minimize() hands back the loss tensor. We own it
    // and must dispose it, otherwise one tensor leaks per step.
    const cost = optimizer.minimize(() => {
      const [input, label] = createTrainBatch(batchSize);
      return loss(label, model(input));
    }, true);
    try {
      const [lossValue] = await cost.data();
      console.log(`step ${step}: loss ${lossValue}`);
    } finally {
      cost.dispose();
    }
    await tf.nextFrame();
  }
}
// Predict
/**
 * Trains the model, then prints the true labels ('A') and the model's
 * predicted labels ('P') for a fresh batch of 10 random samples.
 *
 * @returns {Promise<void>}
 */
async function run() {
  await train();
  // tidy() disposes every tensor created here (batch, model output, argMax
  // result) once printing is done, including when something throws.
  tf.tidy(() => {
    const [input, , origLabel] = createTrainBatch(10);
    // Axis 1 holds the per-class scores of each sample; argMax picks the class.
    const predicted = model(input).argMax(1);
    console.log('A', origLabel);
    console.log('P', Array.from(predicted.dataSync()));
  });
}
// Report failures (e.g. tf.min.js failed to load) instead of leaving an
// unhandled promise rejection.
run().catch((err) => console.error(err));
</script>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment