import tensorflow as tf


def train_step(inp, tar):
    # Teacher forcing: decoder input and label are the target shifted by one.
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    # Build the encoder padding, combined look-ahead, and decoder padding masks.
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    # Record generator and discriminator operations on separate gradient tapes.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Predicted token probabilities from the generator (training=True).
        predictions, _ = generator(inp, tar_inp,
                                   True,
                                   enc_padding_mask,
                                   combined_mask,
                                   dec_padding_mask)

        # Decode the most likely token at each position.
        batch_pred = tf.argmax(predictions, axis=-1)

        # Pad predicted sequences to the target length. pad_sequences runs in
        # NumPy, so this step requires eager execution.
        batch_pred = tf.keras.preprocessing.sequence.pad_sequences(
            batch_pred, padding='post', value=0, maxlen=tar.shape[-1])

        # Discriminator scores for real and generated (input, output) pairs.
        disc_preds_real = discriminator([inp, tar], training=True)
        disc_preds_fake = discriminator([inp, batch_pred], training=True)

        # Adversarial losses for both networks.
        d_loss = discriminator_loss(disc_preds_real, disc_preds_fake)
        g_loss = generator_loss(disc_preds_fake)

    # Compute discriminator gradients and apply them with its optimizer.
    disc_grads = disc_tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_weights))

    # Compute generator gradients and apply them with its optimizer.
    gen_grads = gen_tape.gradient(g_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_weights))
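

# ---------------------------------------------------------------------------
# Context (not part of the original gist): a minimal sketch of the loss
# helpers and driver loop that train_step assumes. The binary cross-entropy
# formulation, the from_logits assumption, EPOCHS, and `dataset` are
# illustrative placeholders, not the author's verified definitions.

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_preds, fake_preds):
    # Real (input, target) pairs should score as 1, generated pairs as 0.
    real_loss = bce(tf.ones_like(real_preds), real_preds)
    fake_loss = bce(tf.zeros_like(fake_preds), fake_preds)
    return real_loss + fake_loss


def generator_loss(fake_preds):
    # The generator improves when the discriminator scores its output as real.
    return bce(tf.ones_like(fake_preds), fake_preds)


# Example driver loop, assuming `dataset` yields (inp, tar) batches of token
# IDs and that generator, discriminator, both optimizers, and create_masks
# are defined elsewhere (e.g. as in the TensorFlow Transformer tutorial).
EPOCHS = 20  # illustrative value
for epoch in range(EPOCHS):
    for inp, tar in dataset:
        train_step(inp, tar)

# Design note: tf.argmax and the NumPy-based pad_sequences both cut the
# gradient path from g_loss back to the generator weights, so as written
# gen_tape.gradient yields None gradients for the generator; discrete-sequence
# GANs typically substitute a differentiable relaxation (e.g. Gumbel-softmax)
# or a policy-gradient estimator for the argmax step.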