@ChunML
Created April 30, 2019 03:35
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out):
    with tf.GradientTape() as tape:
        # Padding mask of the source sequence,
        # used in the Encoder and in the middle
        # Multi-Head Attention of the Decoder
        padding_mask = 1 - tf.cast(tf.equal(source_seq, 0), dtype=tf.float32)
        encoder_output = encoder(source_seq, padding_mask)
        decoder_output = decoder(target_seq_in, encoder_output, padding_mask)
        loss = loss_func(target_seq_out, decoder_output)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss
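The `padding_mask` line marks real tokens with 1.0 and padded positions (token id 0) with 0.0, so attention can ignore padding. A plain-Python sketch of that same logic, framework-free (the function name `padding_mask` is mine, not from the gist):

```python
def padding_mask(source_seq):
    """Return 1.0 for real tokens and 0.0 for padding (token id 0),
    mirroring 1 - tf.cast(tf.equal(source_seq, 0), tf.float32)."""
    return [[0.0 if tok == 0 else 1.0 for tok in seq] for seq in source_seq]

# A batch of one sequence, right-padded with zeros:
padding_mask([[5, 3, 0, 0]])  # [[1.0, 1.0, 0.0, 0.0]]
```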