@alexrhogue
Created May 15, 2021 00:12
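import tf from '@tensorflow/tfjs';
import fs from 'fs';

// As the code below implies, training.json holds { data: [[grid, label], ...] }, where
// grid is a 2D array that flattens to 48 values and label is either [] or three numbers.
const training = JSON.parse(fs.readFileSync('./data/training.json', { encoding: 'utf8' }));
const raw_data = training.data;

// Flatten each 2D grid into one row of 48 features (same result as the reduce/spread version).
const inputs = raw_data.map(d => d[0].flat());
// Records with no label get the sentinel [-1, -1, -1] so every label row has 3 values.
const labels = raw_data.map(d => (d[1].length === 0 ? [-1, -1, -1] : d[1]));

const inputTensor = tf.tensor2d(inputs, [inputs.length, 48]);
const labelTensor = tf.tensor2d(labels, [labels.length, 3]);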
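// Create a sequential model
const model = tf.sequential();
// A single dense layer with no activation: a linear map from 48 inputs to 3 outputs
model.add(tf.layers.dense({ units: 3, inputShape: [48], useBias: true }));

// Adam optimizer, L1 (absolute-difference) loss, and mean squared error as a reported metric
model.compile({
  optimizer: tf.train.adam(),
  loss: tf.losses.absoluteDifference,
  metrics: ['mse'],
});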
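// Top-level await requires running this file as an ES module (.mjs or "type": "module")
await model.fit(inputTensor, labelTensor, {
  batchSize: 5000,
  epochs: 50,
  shuffle: true,
});

// Compare the model's prediction for the first example with its input and label
model.predict(tf.tensor2d([inputs[0]])).print();
console.log(inputs[0]);
console.log(labels[0]);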
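The data-preparation step in the script can be checked on a toy record without loading TensorFlow.js. This is a minimal sketch, assuming the `[grid, label]` record shape the script implies; the 6-by-8 grid, the `prepare`, `flattenGrid` and `toLabel` helpers, and the length check are hypothetical additions, not part of the gist.

```javascript
// Sentinel used by the gist for records whose label array is empty.
const NO_LABEL = [-1, -1, -1];

// Flattens a dense 2D array by one level; for such input this matches
// grid.reduce((p, c) => [...p, ...c], []) without re-copying the accumulator per row.
function flattenGrid(grid) {
  return grid.flat();
}

// Empty labels become a fresh copy of the sentinel; non-empty labels pass through.
function toLabel(label) {
  return label.length === 0 ? [...NO_LABEL] : label;
}

// Hypothetical helper: builds inputs/labels and fails fast if a grid
// does not flatten to the expected feature count (48 in the gist).
function prepare(records, featureCount = 48) {
  const inputs = records.map(([grid]) => flattenGrid(grid));
  const labels = records.map(([, label]) => toLabel(label));
  inputs.forEach((row, i) => {
    if (row.length !== featureCount) {
      throw new Error(`record ${i} has ${row.length} features, expected ${featureCount}`);
    }
  });
  return { inputs, labels };
}

// Usage with two hypothetical records sharing one 6-by-8 grid (values 0..47).
const grid = Array.from({ length: 6 }, (_, r) => Array.from({ length: 8 }, (_, c) => r * 8 + c));
const { inputs, labels } = prepare([[grid, [1, 2, 3]], [grid, []]]);
console.log(inputs[0].length); // 48
console.log(labels[1]); // [ -1, -1, -1 ]
```

A length check like this catches a malformed record before `tf.tensor2d` is asked to reshape the data into `[inputs.length, 48]`.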
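The model in the script is a single dense layer with `useBias: true` and no activation, so each prediction is `y = x · W + b` with `W` of shape [48, 3] and `b` of length 3. The plain-JavaScript sketch below spells out that forward pass and the two error measures. It assumes that `tf.losses.absoluteDifference`, called without weights, averages `|label - prediction|` over every element, and that the `'mse'` metric averages squared error; the helper names are hypothetical, and a 2-input layer stands in for the 48-input one to keep the numbers readable.

```javascript
// Forward pass of a dense layer with bias and no activation: y_j = b_j + sum_i x_i * W[i][j].
function denseForward(x, W, b) {
  return b.map((bias, j) => x.reduce((acc, xi, i) => acc + xi * W[i][j], bias));
}

// Mean of |yTrue - yPred| over every element of the batch (assumed to match the
// unweighted reduction of tf.losses.absoluteDifference).
function meanAbsoluteError(yTrue, yPred) {
  let sum = 0;
  let n = 0;
  yTrue.forEach((row, r) => row.forEach((t, c) => { sum += Math.abs(t - yPred[r][c]); n += 1; }));
  return sum / n;
}

// Mean of (yTrue - yPred)^2 over every element, i.e. the 'mse' metric for equal-length rows.
function meanSquaredError(yTrue, yPred) {
  let sum = 0;
  let n = 0;
  yTrue.forEach((row, r) => row.forEach((t, c) => { sum += (t - yPred[r][c]) ** 2; n += 1; }));
  return sum / n;
}

// Usage with hypothetical weights, bias and labels.
const prediction = denseForward([1, 2], [[1, 0, 2], [0, 1, 1]], [0.5, -1, 0]);
console.log(prediction); // [ 1.5, 1, 4 ]
const yTrue = [[1, 1, 1], [-1, -1, -1]]; // second row is the "no label" sentinel
const yPred = [prediction, [-1, 0, -1]];
console.log(meanAbsoluteError(yTrue, yPred)); // 0.75
console.log(meanSquaredError(yTrue, yPred)); // about 1.708 (10.25 / 6)
```

Note that the loss treats sentinel rows exactly like real labels: for records with no label, training pulls all three outputs toward -1.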
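The training settings `batchSize: 5000` and `epochs: 50` fix how many optimizer updates the run performs. This sketch assumes `fit` slices each (shuffled) epoch into consecutive batches of `batchSize` with a smaller final batch, as Keras-style training loops do; `trainingSchedule` is a hypothetical helper and the example dataset sizes are made up.

```javascript
// Number of Adam updates implied by a dataset size, batch size and epoch count,
// assuming ceil(N / batchSize) batches per epoch with a possibly smaller last batch.
function trainingSchedule(numExamples, batchSize = 5000, epochs = 50) {
  if (!Number.isInteger(numExamples) || numExamples <= 0) {
    throw new RangeError('numExamples must be a positive integer');
  }
  const stepsPerEpoch = Math.ceil(numExamples / batchSize);
  const lastBatchSize = numExamples - (stepsPerEpoch - 1) * batchSize;
  return { stepsPerEpoch, lastBatchSize, totalSteps: stepsPerEpoch * epochs };
}

console.log(trainingSchedule(12000)); // { stepsPerEpoch: 3, lastBatchSize: 2000, totalSteps: 150 }
console.log(trainingSchedule(3000)); // { stepsPerEpoch: 1, lastBatchSize: 3000, totalSteps: 50 }
```

If the dataset has 5000 rows or fewer, every epoch is a single full-batch step, so the whole run amounts to only 50 Adam updates; a smaller `batchSize` or more epochs would be needed for more updates.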