@lazuxd
Created July 20, 2021 13:10
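```python
# Method of the model class. Assumes `import tensorflow as tf` at module
# level, and that words2onehot, sample_word and EOS are defined elsewhere.
def predict_next(self, sentence: str,
                 threshold: float = 0.9) -> str:
    """Predict the continuation of `sentence`, stopping when EOS is sampled."""
    # Reset the recurrent state; the argument 1 is presumably the batch size.
    self.reset_state(1)
    prompt = sentence.strip()
    if not prompt:
        # Without at least one input token, y_hat below would be undefined.
        raise ValueError('sentence must contain at least one token')
    # Feed the prompt one token at a time. Iterating over a str yields single
    # characters, so this assumes a character-level vocabulary.
    for word in prompt:
        x = words2onehot(self.vocab, [word])
        y_hat = self(x)
    # y_hat now holds the distribution over the token that follows the prompt.
    s = ''
    while True:
        # Flatten the model output to 1-D and sample the next token.
        word = sample_word(self.vocab,
                           tf.reshape(y_hat, (-1,)).numpy(),
                           threshold)
        if word == EOS:
            break
        s += word
        # Feed the sampled token back in to get the next distribution.
        x = words2onehot(self.vocab, [word])
        y_hat = self(x)
    return s
```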
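`sample_word`, `words2onehot` and `EOS` are not defined in this gist, and the role of `threshold` is not stated. One plausible reading is a nucleus (top-p) cutoff: keep the most likely tokens until their cumulative probability reaches `threshold`, then sample among them. The stdlib-only toy below assumes that reading. Everything in it is hypothetical: `sample_word` is a top-p sampler written for illustration, `toy_model` is a fixed character-bigram table standing in for the RNN (so, unlike the real model, it carries no hidden state), and `max_len` is an added cap so a model that never samples EOS cannot loop forever. It only mirrors the loop structure of `predict_next`.

```python
import random
from typing import List, Optional, Sequence

EOS = '<eos>'  # hypothetical end-of-sequence token


def sample_word(vocab: List[str], probs: Sequence[float],
                threshold: float = 0.9,
                rng: Optional[random.Random] = None) -> str:
    """Hypothetical top-p sampler: keep the most probable tokens until their
    cumulative probability reaches `threshold`, then sample among them."""
    rng = rng or random.Random()
    # Token indices ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, total = [], 0.0
    for i in order:
        kept.append(i)
        total += probs[i]
        if total >= threshold:
            break
    # random.choices does not require the weights to sum to 1.
    idx = rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
    return vocab[idx]


# Hypothetical stand-in for the RNN: each character deterministically
# "prefers" one successor.
VOCAB = ['h', 'e', 'l', 'o', EOS]
NEXT = {'h': 'e', 'e': 'l', 'l': 'o', 'o': EOS}


def toy_model(prev: str) -> List[float]:
    """Put 0.96 on the preferred next token and 0.01 on each of the others."""
    target = NEXT[prev]
    return [0.96 if tok == target else 0.01 for tok in VOCAB]


def predict_next_sketch(sentence: str, threshold: float = 0.9,
                        max_len: int = 100) -> str:
    """Same loop shape as predict_next, driven by the toy model."""
    probs: Optional[List[float]] = None
    for ch in sentence.strip():  # feed the prompt one character at a time
        probs = toy_model(ch)
    if probs is None:
        raise ValueError('sentence must contain at least one character')
    s = ''
    while len(s) < max_len:
        token = sample_word(VOCAB, probs, threshold)
        if token == EOS:
            break
        s += token
        probs = toy_model(token)  # feed the sampled token back in
    return s
```

For example, `predict_next_sketch('h')` follows h→e→l→o→EOS and returns `'elo'`: with `threshold=0.9`, the single token at 0.96 already clears the cutoff, so sampling is effectively greedy. Lowering `threshold` makes generation greedier. Raising it toward 1.0 lets less likely tokens into the candidate set, which makes the output more varied.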