Skip to content

Instantly share code, notes, and snippets.

@ChunML
Created May 6, 2019 04:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ChunML/84f6e03a91451485d345a112078f9d0d to your computer and use it in GitHub Desktop.
Save ChunML/84f6e03a91451485d345a112078f9d0d to your computer and use it in GitHub Desktop.
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out):
    """Run one optimization step over a single batch and return its loss.

    Relies on ``encoder``, ``decoder``, ``loss_func`` and ``optimizer``
    being defined in the enclosing (module) scope.

    Args:
        source_seq: Source token ids; positions equal to 0 are treated as
            padding. Assumed to be (batch_size, seq_len) -- TODO confirm
            against the data pipeline.
        target_seq_in: Target sequence passed to the decoder as input.
        target_seq_out: Target sequence passed to ``loss_func`` as labels.

    Returns:
        The value returned by ``loss_func`` for this batch.
    """
    # 1.0 where source_seq is non-zero, 0.0 where it is 0 (padding).
    padding_mask = 1 - tf.cast(tf.equal(source_seq, 0), dtype=tf.float32)
    # Manually add two more dimensions
    # so that the mask's shape becomes (batch_size, 1, 1, seq_len)
    padding_mask = tf.expand_dims(padding_mask, axis=1)
    padding_mask = tf.expand_dims(padding_mask, axis=1)

    # Only the forward pass and the loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
        encoder_output = encoder(source_seq, padding_mask)
        decoder_output = decoder(target_seq_in, encoder_output, padding_mask)
        loss = loss_func(target_seq_out, decoder_output)

    # Differentiate and update outside the tape context so the backward
    # computation itself is not recorded.
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment