@vered1986
Last active January 8, 2021 22:44
import numpy as np

def generate_sample_top_k(lm, index2word, k=5, max_tokens=25):
    """
    Generates a string, sampling a word from the top k most probable words in the distribution at each time step.
    :param lm: the language model
    :param index2word: a mapping from the index of a word in the vocabulary to the word itself
    :param k: how many words to keep in the distribution
    :param max_tokens: the maximum number of tokens to generate
    :return: the generated sentence
    """
    generated_sentence = '<s>'
    curr_token = None
    generated_tokens = 0
    while curr_token != '</s>' and generated_tokens < max_tokens:
        curr_distribution = lm(generated_sentence)  # vector of probabilities over the vocabulary
        vocab_size = len(curr_distribution)
        top_k_indices = set(np.argsort(curr_distribution)[-k:])  # indices of the k most probable words
        top_k = np.array([curr_distribution[i] if i in top_k_indices else 0.0
                          for i in range(vocab_size)])
        # Renormalize the kept probabilities to make it a probability distribution again.
        # (Note: applying softmax here would be a bug, since e^0 = 1 gives the
        # zeroed-out words nonzero probability again.)
        top_k /= top_k.sum()
        selected_index = np.random.choice(vocab_size, p=top_k)
        curr_token = index2word[int(selected_index)]
        generated_sentence += ' ' + curr_token
        generated_tokens += 1
    return generated_sentence
generated_str = generate_sample_top_k(stupid_lm, vocab)
print(generated_str)
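The filter-and-renormalize step inside the loop can be illustrated on its own. This is a minimal sketch with a made-up 5-word distribution (the numbers are hypothetical, not from any real model): keep the k largest probabilities, zero out the rest, and rescale so the kept entries sum to 1.

```python
import numpy as np

# Hypothetical probability distribution over a 5-word vocabulary
distribution = np.array([0.05, 0.35, 0.25, 0.20, 0.15])
k = 3

# Indices of the k most probable words
top_k_indices = set(np.argsort(distribution)[-k:])

# Zero out everything outside the top k, then renormalize
filtered = np.array([p if i in top_k_indices else 0.0
                     for i, p in enumerate(distribution)])
filtered /= filtered.sum()

print(filtered)  # only the 3 largest entries remain nonzero
```

After renormalizing, the kept entries (0.35, 0.25, 0.20) are each divided by their sum (0.8), so the result sums to 1 and the excluded words can never be sampled.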
Thanks to Saptarshi Sengupta for the bug fix!
