Skip to content

Instantly share code, notes, and snippets.

@random-forests
Created March 5, 2019 23:01
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save random-forests/05d7cb7469e50b53f83e44486677d783 to your computer and use it in GitHub Desktop.
Save random-forests/05d7cb7469e50b53f83e44486677d783 to your computer and use it in GitHub Desktop.
# See https://github.com/tensorflow/docs/blob/master/site/en/r2/guide/autograph.ipynb
def train_one_step(model, optimizer, x, y):
    """Run a single optimization step on one batch.

    Computes the loss for ``(x, y)`` under a ``tf.GradientTape``, applies the
    resulting gradients to ``model.trainable_variables`` through
    ``optimizer.apply_gradients``, and feeds the batch to ``compute_accuracy``.

    Args:
        model: Callable model; ``model(x)`` produces the logits.
        optimizer: Object exposing ``apply_gradients``.
        x: Input batch passed to ``model``.
        y: Labels passed to ``compute_loss`` and ``compute_accuracy``.

    Returns:
        The value returned by ``compute_loss`` for this batch.
    """
    with tf.GradientTape() as tape:
        # The forward pass must be recorded by the tape so the loss can be
        # differentiated with respect to the model's variables.
        predictions = model(x)
        batch_loss = compute_loss(y, predictions)

    # Read the variable list only after the forward pass has run: a lazily
    # built model may have no variables before its first call.
    params = model.trainable_variables
    gradients = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(gradients, params))

    # Update the accuracy metric with this batch's predictions.
    compute_accuracy(y, predictions)
    return batch_loss
def train(model, optimizer):
    """Train ``model`` for one pass over the dataset from ``mnist_dataset()``.

    Runs ``train_one_step`` on every ``(x, y)`` batch and prints the latest
    loss and ``compute_accuracy.result()`` every 10 steps.

    Args:
        model: Model forwarded to ``train_one_step``.
        optimizer: Optimizer forwarded to ``train_one_step``.

    Returns:
        A ``(step, loss, accuracy)`` tuple: the number of batches processed,
        the loss from the last batch (``0.0`` if the dataset yielded nothing),
        and ``compute_accuracy.result()`` read after the loop finishes.
    """
    train_ds = mnist_dataset()
    step = 0
    loss = 0.0
    for x, y in train_ds:
        step += 1
        loss = train_one_step(model, optimizer, x, y)
        # tf.equal is presumably used (instead of ==) so the check also works
        # when AutoGraph converts this loop and `step` becomes a tensor.
        if tf.equal(step % 10, 0):
            tf.print('Step', step, ': loss',
                     loss, '; accuracy', compute_accuracy.result())
    # `accuracy` is not otherwise assigned in this function; bind it from the
    # same metric the loop reports so the return does not raise NameError.
    accuracy = compute_accuracy.result()
    return step, loss, accuracy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment