# Input signature for the tf.function-compiled train_step below: one spec per
# argument (input_language, target_language). Each is a rank-2 int64 tensor
# whose two dimensions are left unspecified (None), so new batch sizes or
# sequence lengths do not force tf.function to retrace.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64)
    for _ in range(2)  # input_language, target_language
]
@tf.function(input_signature=train_step_signature)
def train_step(input_language, target_language):
    """Run one teacher-forced optimization step of the transformer.

    The target batch is split into the decoder input (every token but the
    last) and the expected output (every token but the first). At each
    position, the model is therefore scored against the next target token.

    Args:
        input_language: Source-side batch; an int64 tensor of shape
            ``(None, None)`` as declared by ``train_step_signature``.
        target_language: Target-side batch with the same spec.

    Side effects:
        Applies one optimizer update to ``transformer.trainable_variables``
        and feeds the step's loss and predictions into the ``training_loss``
        and ``training_accuracy`` metric objects.
    """
    # Shift the target by one position (teacher forcing).
    target_input = target_language[:, :-1]
    target_output = target_language[:, 1:]

    # Masks over the source sequence.
    encoder_padding_mask = maskHandler.padding_mask(input_language)
    # Presumably consumed by the decoder's attention over the encoder output,
    # hence built from the source sequence -- confirm against the model.
    decoder_padding_mask = maskHandler.padding_mask(input_language)

    # The decoder is fed target_input, which is one token shorter than
    # target_language. Its self-attention masks must be sized from
    # target_input; sizing them from target_language makes them one position
    # too long for the sequence the decoder actually sees.
    look_ahead_mask = maskHandler.look_ahead_mask(tf.shape(target_input)[1])
    decoder_target_padding_mask = maskHandler.padding_mask(target_input)
    combined_mask = tf.maximum(decoder_target_padding_mask, look_ahead_mask)

    # Forward pass and loss under the tape, so gradients can be taken.
    with tf.GradientTape() as tape:
        predictions, _ = transformer(
            input_language,
            target_input,
            True,  # presumably the `training` flag -- confirm against transformer
            encoder_padding_mask,
            combined_mask,
            decoder_padding_mask,
        )
        total_loss = padded_loss_function(target_output, predictions)

    # Backward pass and parameter update.
    gradients = tape.gradient(total_loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Accumulate running training metrics.
    training_loss(total_loss)
    training_accuracy(target_output, predictions)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment