@vered1986
Last active August 16, 2019 20:03
import numpy as np


def generate_most_probable(lm, index2word, max_length=100):
    """
    Generates a string, taking the most probable word at each time step (greedy decoding).
    :param lm: the language model, a function that gets a string and returns a distribution over the next word
    :param index2word: a mapping from the index of a word in the vocabulary to the word itself
    :param max_length: maximum number of words to generate, a safeguard in case '</s>' is never predicted
    """
    generated_sentence = '<s>'
    curr_token = None
    num_generated = 0
    # Stop at the end-of-sentence token, or after max_length words
    while curr_token != '</s>' and num_generated < max_length:
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        most_probable_index = int(np.argmax(curr_distribution))  # index of the most probable word
        curr_token = index2word[most_probable_index]
        generated_sentence += ' ' + curr_token
        num_generated += 1
    return generated_sentence


# stupid_lm and vocab are not defined in this snippet; they are assumed to be defined elsewhere
generated_str = generate_most_probable(stupid_lm, vocab)
print(generated_str)
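The snippet depends on a language model and a vocabulary that are not shown here. As a minimal, self-contained sketch of the same greedy strategy, the example below uses a hypothetical five-word vocabulary (`toy_vocab`) and a hypothetical model (`toy_lm`) that looks only at the last word. Neither comes from the original gist, and `greedy_decode` is an illustrative restatement of the idea, not the author's code.

```python
import numpy as np

# Hypothetical toy vocabulary, for illustration only
toy_vocab = ['<s>', 'the', 'cat', 'sleeps', '</s>']


def toy_lm(text):
    """Hypothetical model: the next-word distribution depends only on the last word."""
    next_word = {'<s>': 'the', 'the': 'cat', 'cat': 'sleeps', 'sleeps': '</s>'}
    last_word = text.split()[-1]
    # Give the preferred next word 0.8 and every other word 0.05 (sums to 1.0)
    probs = np.full(len(toy_vocab), 0.05)
    probs[toy_vocab.index(next_word[last_word])] = 0.8
    return probs


def greedy_decode(lm, index2word, max_len=20):
    """Append the most probable word at each step until '</s>' or max_len words."""
    sentence = '<s>'
    for _ in range(max_len):
        token = index2word[int(np.argmax(lm(sentence)))]
        sentence += ' ' + token
        if token == '</s>':
            break
    return sentence


toy_output = greedy_decode(toy_lm, toy_vocab)
print(toy_output)  # <s> the cat sleeps </s>
```

Because the most probable word is always chosen, the output is deterministic: the same model and starting string always produce the same sentence.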