@tehZevo
Last active September 15, 2018 19:06
An example solving the XOR problem with tfjs
var tf = require("@tensorflow/tfjs");
//code modified from:
//https://medium.com/tensorflow/a-gentle-introduction-to-tensorflow-js-dba2e5257702
//define our inputs (combinations of 2 bits, represented as 0s and 1s)
//https://js.tensorflow.org/api/0.12.0/#tensor
var xs = tf.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]);
//define our outputs (xor operation, a simple non-linear problem)
//[0, 0] -> [0], [0, 1] -> [1], etc
var ys = tf.tensor([[0], [1], [1], [0]]);
//create a "sequential" model
//https://js.tensorflow.org/api/0.12.0/#sequential
var model = tf.sequential();
//add a dense/"fully-connected" layer to the model that takes inputs of size 2,
//contains 8 neurons, and uses "tanh" activation function
//https://js.tensorflow.org/api/0.12.0/#layers.dense
//see https://js.tensorflow.org/api/0.12.0/#layers.activation for other activations
model.add(tf.layers.dense({units: 8, inputDim: 2, activation: 'tanh'}));
//add a second dense layer with 1 neuron and "sigmoid" activation
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
//compile our model using the "adam" optimizer with a learning rate of 0.1,
//and binary crossentropy loss
//note: compile() has no "lr" option; pass the rate to the optimizer instead
//https://js.tensorflow.org/api/0.12.0/#class:Model
model.compile({optimizer: tf.train.adam(0.1), loss: 'binaryCrossentropy'});
//train the model for 1000 epochs (complete passes over the dataset),
//using a batch size of 4 (4 input/output pairs per iteration)
//https://towardsdatascience.com/epoch-vs-iterations-vs-batch-size-4dfb9c7ce9c9
model.fit(xs, ys, {
  batchSize: 4,
  epochs: 1000,
  callbacks: {
    //after every epoch, print the current loss
    onEpochEnd: async (epoch, log) => {
      console.log(`Epoch ${epoch}: loss = ${log.loss}`);
    }
  }
}).then(() => {
  //after training, predict answers for our original input data
  model.predict(xs).print();
  //output should be close to [[0], [1], [1], [0]]
});
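The `ys` targets above are just the XOR truth table. A quick plain-JS sanity check (no tfjs needed) that each output row is `a ^ b` for the corresponding input row:

```javascript
//recompute the expected outputs from the inputs with bitwise XOR
var xorCheck = [[0, 0], [0, 1], [1, 0], [1, 1]].map(([a, b]) => a ^ b);
console.log(xorCheck); //-> [ 0, 1, 1, 0 ], matching ys
```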
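Because the last layer is a sigmoid, `predict` returns probabilities, not hard 0/1 labels. A minimal sketch of thresholding them at 0.5; the `predictions` values here are hypothetical stand-ins for what a trained network might output (in practice you would read the real values with `model.predict(xs).dataSync()`):

```javascript
//hypothetical sigmoid outputs for inputs [[0,0],[0,1],[1,0],[1,1]]
var predictions = [0.03, 0.97, 0.96, 0.05];
//the targets from the XOR truth table
var targets = [0, 1, 1, 0];

//threshold each probability at 0.5 to get a hard 0/1 label
var labels = predictions.map(p => p >= 0.5 ? 1 : 0);

//count how many labels match the targets
var correct = labels.filter((l, i) => l === targets[i]).length;
console.log(`labels: ${labels}, accuracy: ${correct / targets.length}`);
```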