Skip to content

Instantly share code, notes, and snippets.

@extremeheat
Last active June 30, 2021 08:28
Show Gist options
  • Save extremeheat/d1a5c31b5e3f316038f425a0a8922337 to your computer and use it in GitHub Desktop.
Save extremeheat/d1a5c31b5e3f316038f425a0a8922337 to your computer and use it in GitHub Desktop.
import { py, python, PyClass } from "JSPyBridge"
const tf = await python('tensorflow')
class KerasCallback extends PyClass {
constructor() {
super(tf.keras.callbacks.Callback())
}
on_epoch_end(epoch, logs) {
if (logs.loss < 0.4) {
console.log('\nReached 60% accuracy so cancelling training!')
py`setattr(${this.superclass()}.model, 'stop_training', True)`
}
}
}
const mnist = await tf.keras.datasets.fashion_mnist
const [[training_images, training_labels], [test_images, test_labels]] = await mnist.load_data()
const trainingImages = await py`${training_images} / 255.0`
const testImages = await py`${test_images} / 255.0`
console.log('traiing', await trainingImages.shape, await training_labels.shape)
console.log('test', await testImages.shape, await test_labels.shape)
const model = await tf.keras.models.Sequential([
await tf.keras.layers.Flatten(),
await tf.keras.layers.Dense$(512, { activation: await tf.nn.relu }),
await tf.keras.layers.Dense$(10, { activation: await tf.nn.softmax })
])
await model.compile$({ optimizer: 'adam', loss: 'sparse_categorical_crossentropy' })
await model.fit$(trainingImages, await training_labels, { epochs: 5, callbacks: [new KerasCallback()], $timeout: Infinity })
python.exit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment