Skip to content

Instantly share code, notes, and snippets.

@VXU1230
Created March 19, 2019 18:04
Show Gist options
Save VXU1230/a8254f3013fad489011ca97aae444b38 to your computer and use it in GitHub Desktop.
Train the model on a single batch (one optimization step).
@tf.function
def train_step(model, loss_fn, optimizer, target, context, label, clip_value=5.0):
    """Run one optimization step on a single batch.

    Args:
        model: Callable invoked as ``model(target, context)``. Its
            ``trainable_variables`` are the variables that get updated.
        loss_fn: Callable invoked as ``loss_fn(label, predictions)`` that
            returns the batch loss.
        optimizer: Object exposing ``apply_gradients(grads_and_vars)``.
        target: First model input for this batch.
        context: Second model input for this batch.
        label: Ground truth passed to ``loss_fn``.
        clip_value: Each gradient element is clipped to
            ``[-clip_value, clip_value]`` before being applied
            (default 5.0, matching the previous hard-coded bound).

    Returns:
        A ``(batch_loss, grad_norm)`` tuple. ``grad_norm`` is the global L2
        norm of the clipped gradients that were actually applied.
    """
    with tf.GradientTape() as tape:
        predictions = model(target, context)
        batch_loss = loss_fn(label, predictions)

    variables = model.trainable_variables
    gradients = tape.gradient(batch_loss, variables)

    # Filter (gradient, variable) pairs together. Filtering the gradients
    # alone and zipping with the full variable list would shift every
    # gradient after a None onto the wrong variable.
    grads_and_vars = [
        (tf.clip_by_value(g, -clip_value, clip_value), v)
        for g, v in zip(gradients, variables)
        if g is not None
    ]
    optimizer.apply_gradients(grads_and_vars)

    # Global L2 norm: sqrt of the sum of squared elements across all
    # gradient tensors (not the squared per-tensor means).
    grad_norm = tf.linalg.global_norm([g for g, _ in grads_and_vars])
    return batch_loss, grad_norm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment