Skip to content

Instantly share code, notes, and snippets.

@Shivam-316
Created November 17, 2020 17:35
Show Gist options
  • Save Shivam-316/7b03e40fe49a651c654f443a5ead5d9a to your computer and use it in GitHub Desktop.
Save Shivam-316/7b03e40fe49a651c654f443a5ead5d9a to your computer and use it in GitHub Desktop.
@tf.function
def train(input, target, enc_hidden):
    """Run one teacher-forced training step of the encoder-decoder model.

    Encodes ``input``, then decodes ``target`` one timestep at a time. The
    ground-truth token at step ``t - 1`` is fed as the decoder input for
    step ``t`` (teacher forcing). The per-step losses are summed, and the
    gradients of that sum are applied to the encoder and decoder variables
    through ``optimizer``.

    Relies on ``encoder``, ``decoder``, ``loss_fn`` and ``optimizer`` being
    defined at module level elsewhere.

    Args:
        input: Source batch passed to ``encoder``. The name shadows the
            builtin but is kept so keyword callers still work.
        target: Target batch, indexed as ``target[:, t]``. Its second axis is
            iterated as the time axis. It must have a static length, because
            ``target.shape[1]`` drives a Python ``range`` and an ``int()``.
        enc_hidden: Initial state passed to ``encoder``.

    Returns:
        The summed per-step loss divided by ``target.shape[1]``.
    """
    total_loss = 0.0
    with tf.GradientTape() as tape:
        # The encoder's per-step outputs are not used by this decoder, so
        # only the final (h, c) states are kept.
        _, enc_h, enc_c = encoder(input, enc_hidden)
        # The decoder starts from the encoder's final state.
        # NOTE(review): this assumes decoder(x, [h, c]) returns
        # (output, h, c), as implied by the 3-way unpack. Confirm against
        # the decoder definition.
        dec_states = [enc_h, enc_c]
        dec_input = tf.expand_dims(target[:, 0], 1)
        # Step 0 is only ever used as input, never scored, so scoring
        # begins at t = 1.
        for t in range(1, target.shape[1]):
            dec_output, dec_h, dec_c = decoder(dec_input, dec_states)
            # Carry the decoder's own state into the next step. Re-feeding
            # the encoder states would reset the recurrence at every step.
            dec_states = [dec_h, dec_c]
            total_loss += loss_fn(target[:, t], dec_output)
            # Teacher forcing: the next input is the ground-truth token,
            # not the model's prediction.
            dec_input = tf.expand_dims(target[:, t], 1)
    # Reported loss only. It divides by the full target length, including
    # the unscored step 0, to keep the original scaling.
    batch_loss = total_loss / int(target.shape[1])
    variables = encoder.trainable_variables + decoder.trainable_variables
    # Gradients are taken on the summed (not averaged) loss, as before.
    gradients = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return batch_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment