Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Input signature for `train_step`: two rank-2 int64 tensors whose dimensions
# are both dynamic, so the traced graph is reused across batch sizes and
# sequence lengths. Presumably (batch, seq_len) token ids for the source and
# target sentences respectively — confirm against the input pipeline.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64)
    for _ in range(2)
]
@tf.function(input_signature=train_step_signature)
def train_step(input_language, target_language):
    """Run one teacher-forced optimisation step on a single batch.

    The target batch is split into a decoder input (all tokens but the last)
    and a prediction target (all tokens but the first). The transformer is
    run on the source batch and the decoder input, the padded loss is
    computed, gradients are applied with `optimizer`, and the
    `training_loss` / `training_accuracy` metrics are updated.

    Args:
        input_language: int64 tensor of rank 2 (see `train_step_signature`);
            presumably (batch, src_len) source token ids — confirm.
        target_language: int64 tensor of rank 2; presumably
            (batch, tgt_len) target token ids — confirm.

    Returns:
        None. All effects go through `optimizer`, `training_loss` and
        `training_accuracy`.
    """
    # Teacher forcing: the decoder consumes tokens [0, n-1) and is trained
    # to predict tokens [1, n).
    target_input = target_language[:, :-1]
    target_output = target_language[:, 1:]

    # Create masks.
    encoder_padding_mask = maskHandler.padding_mask(input_language)
    # Built from the source sequence; presumably masks padded encoder outputs
    # in the decoder's encoder-decoder attention — confirm against `transformer`.
    decoder_padding_mask = maskHandler.padding_mask(input_language)
    # The decoder's self-attention runs over `target_input`, which is one
    # token shorter than `target_language`. Both self-attention masks must
    # therefore be sized from `target_input`. Sizing them from
    # `target_language` yields a mask for n tokens while the decoder sees n-1.
    look_ahead_mask = maskHandler.look_ahead_mask(tf.shape(target_input)[1])
    decoder_target_padding_mask = maskHandler.padding_mask(target_input)
    # Element-wise max merges the padding and look-ahead masks into one.
    combined_mask = tf.maximum(decoder_target_padding_mask, look_ahead_mask)

    # Run training step: record the forward pass so gradients can be taken.
    with tf.GradientTape() as tape:
        predictions, _ = transformer(
            input_language,
            target_input,
            True,  # presumably the `training` flag — confirm.
            encoder_padding_mask,
            combined_mask,
            decoder_padding_mask,
        )
        total_loss = padded_loss_function(target_output, predictions)

    gradients = tape.gradient(total_loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Accumulate running metrics for this step.
    training_loss(total_loss)
    training_accuracy(target_output, predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment