@NMZivkovic
Created August 30, 2019 17:26
import tensorflow as tf

# `transformer`, `maskHandler`, `optimizer`, `padded_loss_function`,
# `training_loss` and `training_accuracy` are assumed to be defined
# earlier in the accompanying article.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(input_language, target_language):
    # Shift the target sequence: the decoder receives all tokens except
    # the last one and learns to predict all tokens except the first one.
    target_input = target_language[:, :-1]
    target_output = target_language[:, 1:]

    # Create masks. They are built from the shifted decoder input so that
    # their sequence length matches the decoder's attention matrices.
    encoder_padding_mask = maskHandler.padding_mask(input_language)
    decoder_padding_mask = maskHandler.padding_mask(input_language)
    look_ahead_mask = maskHandler.look_ahead_mask(tf.shape(target_input)[1])
    decoder_target_padding_mask = maskHandler.padding_mask(target_input)
    combined_mask = tf.maximum(decoder_target_padding_mask, look_ahead_mask)

    # Run training step: forward pass and loss inside the tape,
    # then backpropagate and apply the gradients.
    with tf.GradientTape() as tape:
        predictions, _ = transformer(input_language, target_input, True,
                                     encoder_padding_mask, combined_mask,
                                     decoder_padding_mask)
        total_loss = padded_loss_function(target_output, predictions)

    gradients = tape.gradient(total_loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Update the running metrics
    training_loss(total_loss)
    training_accuracy(target_output, predictions)
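For context, here is a minimal sketch of how train_step might be driven from a training loop. It assumes a batched tf.data.Dataset yielding (input, target) pairs of int64 token IDs, plus the training_loss and training_accuracy metrics used above; train_dataset and EPOCHS are placeholder names, not part of this gist.

# Minimal training-loop sketch. Assumption: `train_dataset` is a batched
# tf.data.Dataset yielding (input_ids, target_ids) int64 pairs; neither
# `train_dataset` nor EPOCHS is defined in this gist.
EPOCHS = 20

for epoch in range(EPOCHS):
    # Reset the running metrics at the start of every epoch
    training_loss.reset_states()
    training_accuracy.reset_states()

    for (input_language, target_language) in train_dataset:
        train_step(input_language, target_language)

    print(f'Epoch {epoch + 1}: Loss {training_loss.result():.4f}, '
          f'Accuracy {training_accuracy.result():.4f}')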