@lazuxd
Created July 19, 2021 18:36
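import numpy as np


def sample_word(vocabulary: list, prob: np.ndarray, threshold: float) -> str:
    # Sample a word from the vocabulary according to the 'prob'
    # probability distribution (the softmax output of our model),
    # restricted to the most probable words (top-p / nucleus sampling).
    prob = prob.tolist()
    vocab_prob = [[vocabulary[i], prob[i]] for i in range(len(prob))]
    # Sort words from most to least probable.
    vocab_prob.sort(reverse=True, key=lambda e: e[1])
    s = 0
    for i in range(len(vocab_prob)):
        # Once the cumulative probability exceeds 'threshold',
        # zero out every remaining (less likely) word.
        if s > threshold:
            vocab_prob[i][1] = 0
        s += vocab_prob[i][1]
    vocab = [w for w, p in vocab_prob]
    # Renormalize the kept probabilities so they sum to 1.
    prob = np.array([p/s for w, p in vocab_prob])
    return np.random.choice(vocab, p=prob)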
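
The same cutoff can probably be expressed without the Python loop. The sketch below is not part of the original gist: `sample_word_vectorized` and its optional `rng` argument are hypothetical names introduced here for illustration. It assumes `threshold` is non-negative and that `prob` sums to roughly 1.

```python
import numpy as np


def sample_word_vectorized(vocabulary, prob, threshold, rng=None):
    # Hypothetical helper (an assumption, not the gist author's code).
    rng = np.random.default_rng() if rng is None else rng
    prob = np.asarray(prob, dtype=float)
    # Indices of the words, from most to least probable.
    order = np.argsort(prob)[::-1]
    sorted_prob = prob[order]
    # Probability mass accumulated *before* each word in sorted order.
    mass_before = np.concatenate(([0.0], np.cumsum(sorted_prob)[:-1]))
    # Drop a word once the mass ahead of it already exceeds the threshold.
    kept = np.where(mass_before > threshold, 0.0, sorted_prob)
    # Renormalize the surviving probabilities so they sum to 1.
    kept = kept / kept.sum()
    return vocabulary[order[rng.choice(len(order), p=kept)]]


vocabulary = ["a", "b", "c", "d"]
prob = np.array([0.5, 0.3, 0.15, 0.05])
samples = {sample_word_vectorized(vocabulary, prob, 0.6) for _ in range(200)}
```

With `threshold = 0.6`, only `"a"` and `"b"` survive the cutoff, since the mass ahead of `"c"` is already 0.8. Passing a seeded `np.random.default_rng(...)` as `rng` should make the draws reproducible.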