@ChunML
Created April 28, 2019 12:10
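import tensorflow as tf

# encoder, decoder, loss_func and optimizer are defined elsewhere (not shown in this gist).

@tf.function  # trace into a graph for faster repeated calls
def train_step(source_seq, target_seq_in, target_seq_out):
    # Record the forward pass so gradients can be computed afterwards
    with tf.GradientTape() as tape:
        encoder_output = encoder(source_seq)
        decoder_output = decoder(target_seq_in, encoder_output)
        loss = loss_func(target_seq_out, decoder_output)
    # Update encoder and decoder weights jointly
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss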
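`train_step` receives two versions of the target sequence. Their names suggest teacher forcing. Under that reading, the decoder input is the target prefixed with a start token, and the expected output is the same tokens shifted one step left and closed with an end token. The plain-Python sketch below builds such a pair. The token ids (`PAD = 0`, `START = 1`, `END = 2`) and the helper `make_decoder_io` are hypothetical and are not taken from the gist.

```python
# Hypothetical special-token ids (assumption, not from the gist).
PAD, START, END = 0, 1, 2


def make_decoder_io(target_tokens, max_len):
    """Build (target_seq_in, target_seq_out) for teacher forcing.

    target_seq_in  = <start> + tokens, right-padded to max_len
    target_seq_out = tokens + <end>,   right-padded to max_len
    At position t the decoder sees target_seq_in[t] and is trained to
    predict target_seq_out[t], i.e. the next token of the target.
    """
    seq_in = [START] + list(target_tokens)
    seq_out = list(target_tokens) + [END]
    if len(seq_in) > max_len:
        raise ValueError("target sequence does not fit in max_len")
    padding = [PAD] * (max_len - len(seq_in))
    return seq_in + padding, seq_out + padding
```

For example, `make_decoder_io([7, 8, 9], 6)` returns `([1, 7, 8, 9, 0, 0], [7, 8, 9, 2, 0, 0])`. At each unpadded position, the output token is the input token one step ahead.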
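`loss_func` is not defined in this gist. When target sequences are padded, a common choice is cross-entropy averaged only over non-padding positions, so that padding does not dilute the loss. The plain-Python sketch below assumes padding id 0 and one probability distribution per position. The name `masked_cross_entropy` is hypothetical. A TensorFlow version would compute the same quantity on batched logits, using a per-token loss multiplied by a padding mask.

```python
import math

PAD = 0  # assumed padding id (not confirmed by the gist)


def masked_cross_entropy(targets, probs):
    """Mean negative log-likelihood over non-padding positions.

    targets: token ids for one sequence
    probs:   one probability distribution (list summing to 1) per position
    """
    total, count = 0.0, 0
    for token, dist in zip(targets, probs):
        if token == PAD:
            continue  # padded positions contribute nothing to the loss
        total += -math.log(dist[token])
        count += 1
    return total / count if count else 0.0
```

With targets `[1, 2, 0]` and distributions that assign probability 0.5 to each correct token, the loss is ln 2 (about 0.693). Changing the distribution at the padded third position leaves the result unchanged.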