@stephkoltun
Created June 14, 2018 20:35
// Tiny TFJS train / predict example.
// Assumes TensorFlow.js is loaded as the global `tf` and that the page
// contains an element with id 'micro_out_div'.
async function myFirstTfjs() {
  // Create a simple model: one dense layer with a single unit and a scalar input.
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));

  // Prepare the model for training: specify the loss and the optimizer.
  model.compile({
    loss: 'meanSquaredError',
    optimizer: 'sgd'
  });

  // Generate some synthetic data for training (y = 2x - 1).
  const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
  const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

  // Train the model using the data.
  await model.fit(xs, ys, {epochs: 250});

  // Run inference on a data point the model hasn't seen (x = 20).
  // The appended text is the output tensor's string form; its value
  // should be approximately 39, since 2 * 20 - 1 = 39.
  document.getElementById('micro_out_div').innerText +=
      model.predict(tf.tensor2d([20], [1, 1]));
}
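
To see what the one-unit dense layer is fitting, here is a minimal dependency-free sketch of the same idea: a line `w * x + b` trained by gradient descent on mean squared error. It is not TensorFlow.js code and makes no claim about its internals. The helper `trainLinear`, the zero starting weights, the learning rate of 0.01, and one full-batch update per epoch are all assumptions made for illustration.

```javascript
// Hypothetical helper: fits y ≈ w * x + b with full-batch gradient descent
// on mean squared error. Not part of TensorFlow.js.
function trainLinear(xs, ys, epochs, learningRate) {
  let w = 0; // assumption: start from zero rather than a random initializer
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const err = w * xs[i] + b - ys[i]; // prediction error for sample i
      gradW += (2 / n) * err * xs[i];    // d(MSE)/dw
      gradB += (2 / n) * err;            // d(MSE)/db
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return {w, b};
}

// Same synthetic data as the gist (y = 2x - 1).
const xs = [-1, 0, 1, 2, 3, 4];
const ys = [-3, -1, 1, 3, 5, 7];

// Assumed hyperparameters: 250 epochs, learning rate 0.01.
const {w, b} = trainLinear(xs, ys, 250, 0.01);
const prediction = w * 20 + b;
console.log(`w=${w.toFixed(3)} b=${b.toFixed(3)} prediction=${prediction.toFixed(3)}`);
```

Under these assumptions the learned `w` and `b` approach 2 and -1 without reaching them exactly in 250 epochs, so the prediction for x = 20 lands near 39 rather than on it.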