Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 1, 2021 12:32
Show Gist options
  • Save lazuxd/c55b7f30a6312838cb3bf90264fcadb9 to your computer and use it in GitHub Desktop.
Save lazuxd/c55b7f30a6312838cb3bf90264fcadb9 to your computer and use it in GitHub Desktop.
def predict_next(self, sentence: str, max_words: "int | None" = None) -> str:
    """Continue ``sentence`` by sampling words from the model until ``EOS``.

    The words of ``sentence`` are fed through the model one at a time to
    build up the recurrent state. Words missing from ``self.vocab`` are
    replaced by ``UNK``. Words are then sampled one by one from the model's
    output and fed back in. Sampling stops when ``EOS`` is sampled, or once
    ``max_words`` words have been generated.

    Args:
        sentence: Whitespace-separated prompt; must contain at least one word.
        max_words: Optional upper bound on the number of generated words.
            ``None`` (the default) keeps sampling until ``EOS`` is drawn.

    Returns:
        The generated continuation without the ``EOS`` token. Each word is
        preceded by a single space, so the result is ``''`` if ``EOS`` is
        sampled first.

    Raises:
        ValueError: If ``sentence`` contains no words. In that case there is
            no model output to sample the first generated word from.
    """
    # str.split() with no argument already drops surrounding whitespace.
    prompt = sentence.split()
    if not prompt:
        raise ValueError("sentence must contain at least one word")

    # Initial recurrent state: zeros of shape (1, self.a_size).
    a = np.zeros((1, self.a_size))
    for word in prompt:
        # Check against the same vocabulary the one-hot encoder uses, so any
        # word it cannot encode is mapped to UNK.
        if word not in self.vocab:
            word = UNK
        x = words2onehot(self.vocab, [word])
        a, y_hat = self(a, x)

    generated = []
    while max_words is None or len(generated) < max_words:
        # The model output is flattened to 1-D before sampling
        # (presumably (1, vocab_size) -> (vocab_size,) — TODO confirm).
        word = sample_word(self.vocab, tf.reshape(y_hat, (-1,)))
        if word == EOS:
            break
        generated.append(word)
        # Feed the sampled word back in to condition the next prediction.
        x = words2onehot(self.vocab, [word])
        a, y_hat = self(a, x)

    # Keep the original output format: every generated word prefixed by a space.
    return ''.join(' ' + w for w in generated)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment