Skip to content

Instantly share code, notes, and snippets.

@Shivam-316
Created November 17, 2020 17:51
Show Gist options
  • Save Shivam-316/82b075a9ac4278d8c0e7b4e8efe75a38 to your computer and use it in GitHub Desktop.
Save Shivam-316/82b075a9ac4278d8c0e7b4e8efe75a38 to your computer and use it in GitHub Desktop.
def predict(input, *, max_len=None):
    """Greedily decode one tokenized source sentence into Hindi text.

    Runs ``encoder`` over ``input``, then calls ``decoder`` one step at a
    time. Each step feeds back the arg-max token id from the previous step.
    Decoding stops when the Hindi ``<eos>`` token is predicted or after
    ``max_len`` steps.

    Args:
        input: Tokenized source sentence. It is indexed as ``input[:, 0]`` and
            ``input.shape[1]``, so it is presumably a ``(1, seq_len)`` tensor
            of token ids. TODO: confirm against the caller.
        max_len: Maximum number of decoding steps. Defaults to
            ``input.shape[1]``, which matches the original behavior.

    Returns:
        The predicted words joined by single spaces, without ``<eos>``.
        Text is returned both when decoding stops at ``<eos>`` and when it
        reaches ``max_len``; the original returned ``None`` in the latter case.
    """
    # Zero initial [h, c] state for the encoder. The 512 presumably matches
    # the encoder's LSTM width, and 1 the batch size. TODO: confirm.
    hidden = [tf.zeros((1, 512)), tf.zeros((1, 512))]
    _, enc_h, enc_c = encoder(input, hidden)
    states = [enc_h, enc_c]

    if max_len is None:
        max_len = input.shape[1]
    eos_id = hindi_tokenizer.word_index['<eos>']

    result = []
    # NOTE(review): the first decoder input is the first *source* token
    # (presumably a start marker). Confirm it has the same id as the
    # target-side start token in hindi_tokenizer.
    dec_input = tf.expand_dims(input[:, 0], 0)
    for _ in range(max_len):
        # Assumes the decoder, like the encoder, returns (output, h, c).
        # Carrying the returned states forward lets each step condition on
        # the tokens already emitted instead of restarting from the encoder
        # states every time.
        dec_output, dec_h, dec_c = decoder(dec_input, states)
        states = [dec_h, dec_c]
        output_id = tf.math.argmax(dec_output[0], -1)[0].numpy()
        if output_id == eos_id:
            break
        result.append(hindi_tokenizer.index_word[output_id])
        dec_input = tf.expand_dims([output_id], 0)
    # Reached on <eos> or when max_len is exhausted, so a string is always
    # returned.
    return ' '.join(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment