@vered1986
Last active August 16, 2019 23:58
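import numpy as np


def generate_sample(lm, index2word, max_tokens=25):
    """
    Generates a string by sampling a word from the model's distribution at each time step.
    :param lm - the language model: maps the text generated so far to a vector of
                probabilities over the vocabulary
    :param index2word - a mapping from the index of a word in the vocabulary to the word itself
    :param max_tokens - the maximum number of tokens to generate
    """
    generated_sentence = '<s>'
    generated_tokens = 0
    curr_token = None

    # Stop at the end-of-sentence token or once max_tokens words have been generated
    while curr_token != '</s>' and generated_tokens < max_tokens:
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        # Sample an index in proportion to its probability; the range comes from
        # index2word rather than the global vocab, so the function is self-contained
        selected_index = np.random.choice(len(index2word), p=curr_distribution)
        curr_token = index2word[int(selected_index)]
        generated_sentence += ' ' + curr_token
        generated_tokens += 1

    return generated_sentence


# stupid_lm and vocab are not defined in this snippet
generated_str = generate_sample(stupid_lm, vocab)
print(generated_str)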
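
The call above relies on `stupid_lm` and `vocab`, which this snippet does not define. As a minimal sketch, assuming the model takes the text generated so far and returns one probability per vocabulary entry, a hypothetical uniform stand-in could look like the following. The names `toy_vocab` and `uniform_lm` are illustrative and not part of the original code.

```python
import numpy as np

# Hypothetical toy vocabulary; the original `vocab` is defined elsewhere.
toy_vocab = ['</s>', 'the', 'cat', 'sat']


def uniform_lm(sentence):
    """
    Assumed interface: return one probability per word in toy_vocab.
    This stand-in ignores the sentence and treats every word as equally likely.
    """
    return np.full(len(toy_vocab), 1.0 / len(toy_vocab))


# One sampling step, mirroring the loop body of generate_sample
rng = np.random.default_rng(0)  # seeded so repeated runs draw the same token
distribution = uniform_lm('<s>')
sampled_index = rng.choice(len(toy_vocab), p=distribution)
sampled_token = toy_vocab[int(sampled_index)]
print(sampled_token)
```

Any callable with this shape could stand in for the `lm` argument, as long as the probabilities it returns sum to 1 and line up with the entries of the index-to-word mapping.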