Last active January 10, 2025 07:45
thomwolf/1a5a29f6962089e871b94cbd09daf317
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: 1-D logits distribution, shape (vocabulary size,); modified in place
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map positions in the sorted order back to vocabulary indices
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (`model` and `input` are pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1),
# apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
```
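The top-k step compares every logit with the k-th largest value using a strict `<`, so tokens tied with that value are all kept, and more than `top_k` tokens can survive. The sketch below isolates that single comparison on a made-up logits vector; the helper name `top_k_mask` is hypothetical and not part of the gist.

```python
import torch


def top_k_mask(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Boolean mask of the entries the gist's top-k step would remove (1-D logits).

    Hypothetical helper that reproduces only the top-k comparison from the gist.
    """
    top_k = min(top_k, logits.size(-1))  # same safety check as the gist
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    # Strict '<': any logit equal to the k-th largest value is kept.
    return logits < kth_largest


logits = torch.tensor([3.0, 2.0, 2.0, 1.0])
removed = top_k_mask(logits, top_k=2)
print(removed.tolist())  # [False, False, False, True]
print(int((~removed).sum()))  # 3 tokens survive, not 2
```

Ties are rare with real floating-point logits, but if an exact count of k tokens matters, building the mask from the indices returned by `torch.topk` avoids this behavior.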
@LeeSinLiang
Thanks a lot for your time and answer.
Has anyone else run into this error?
RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype
I changed the scatter line to the following, and it works now:
```python
indices_to_remove = torch.zeros_like(logits, dtype=sorted_indices_to_remove.dtype).scatter_(
    dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
```
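The scatter approach in this comment also lifts the batch-size-1 restriction, because the mask is mapped back to vocabulary order per row. The sketch below assumes a `(batch, vocab)` logits tensor and shows a batched nucleus filter built that way; the function name `top_p_filter_batched` is hypothetical. The destination of `scatter_` is created with `dtype=torch.bool` so it matches the boolean `src`, which is the dtype mismatch the error message complains about.

```python
import torch
import torch.nn.functional as F


def top_p_filter_batched(logits: torch.Tensor, top_p: float,
                         filter_value: float = float("-inf")) -> torch.Tensor:
    """Nucleus-filter a (batch, vocab) logits tensor and return a new tensor.

    Hypothetical batched variant; the input tensor is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Same rule as the gist: drop tokens past top_p, then shift right by one
    # so the token that crosses the threshold is kept.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, row by row.
    # The destination is bool so that it matches the bool src tensor.
    to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_to_remove)
    return logits.masked_fill(to_remove, filter_value)


logits = torch.log(torch.tensor([[0.5, 0.3, 0.2],
                                 [0.1, 0.6, 0.3]]))
filtered = top_p_filter_batched(logits, top_p=0.7)
print(torch.isinf(filtered).tolist())  # [[False, False, True], [True, False, False]]
```

Using `masked_fill` instead of in-place indexing returns a new tensor, so the caller's logits stay untouched.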
Note that this implementation does not keep only the top_p probability mass: it also keeps the full probability of the token that straddles the top_p boundary. Here is a (NumPy, not PyTorch) implementation that always samples from exactly the top_p probability mass: https://gist.github.com/calvinmccarter/eaa9ee398606352e6e1df4b50e62881c
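To make the difference concrete, the sketch below first measures how much mass the gist's shifted-mask rule keeps, then shows one possible way to keep exactly top_p: the straddling token is truncated to whatever budget is left instead of being kept in full. This truncation scheme is an assumption for illustration and is not a port of the linked NumPy gist; both function names are hypothetical.

```python
import torch


def kept_mass_with_shift(probs: torch.Tensor, top_p: float) -> float:
    """Probability mass kept by the gist's shifted-mask rule (1-D probs)."""
    sorted_probs, _ = torch.sort(probs, descending=True)
    remove = torch.cumsum(sorted_probs, dim=-1) > top_p
    remove[1:] = remove[:-1].clone()  # keep the token that crosses top_p
    remove[0] = False
    return sorted_probs[~remove].sum().item()


def exact_top_p_probs(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Renormalized distribution built from exactly top_p of the mass (1-D probs).

    Assumed scheme: the straddling token keeps only the leftover budget.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    # Mass held by all strictly higher-ranked tokens.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    # Each token gets min(its probability, remaining budget), never below 0.
    kept = torch.clamp(torch.minimum(sorted_probs, top_p - mass_before), min=0.0)
    # Scatter back to vocabulary order; both tensors are float, so dtypes match.
    out = torch.zeros_like(probs).scatter_(-1, sorted_indices, kept)
    return out / out.sum()


probs = torch.tensor([0.5, 0.3, 0.2])
print(kept_mass_with_shift(probs, 0.7))  # about 0.8, more than top_p
print(exact_top_p_probs(probs, 0.7))  # about [0.714, 0.286, 0.0]
```

With top_p = 0.7, the shifted rule keeps the first two tokens in full (mass 0.8), while the truncated version keeps 0.5 + 0.2 = 0.7 before renormalizing.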
@not-hermione
Consider this example:
We want to create a boolean mask called sorted_indices_to_remove to identify which indices in cumulative_probs need to be removed. Specifically, we want to remove indices where the corresponding value in cumulative_probs is greater than 13. Notice that the index corresponding to the value 12 is also marked as True in sorted_indices_to_remove, which we don't want to remove. To address this issue, we use the following two lines of code:
```python
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```
These two lines shift the values in sorted_indices_to_remove to the right by one along the last dimension and then set the first value along the last dimension to False. This ensures that the index corresponding to the value 12 in cumulative_probs is not marked as True in sorted_indices_to_remove.
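The effect of the shift can be checked on a small distribution. The values below are made up for illustration (they are not the example from the comment above) and are already sorted in descending order, so no `torch.sort` is needed. Without the shift, only the first token survives and its mass (0.4) falls short of top_p = 0.6; with the shift, the token that crosses the threshold is kept too, and the kept mass (0.7) reaches top_p.

```python
import torch

probs = torch.tensor([0.4, 0.3, 0.2, 0.1])  # illustrative, sorted descending
top_p = 0.6
cumulative_probs = torch.cumsum(probs, dim=-1)  # about [0.4, 0.7, 0.9, 1.0]

# Mask before the shift: every token whose cumulative probability exceeds top_p.
before_shift = cumulative_probs > top_p
# Shift right by one and never remove the first token, as in the gist.
after_shift = before_shift.clone()
after_shift[1:] = before_shift[:-1]
after_shift[0] = False

mass_without_shift = probs[~before_shift].sum()  # 0.4, below top_p
mass_with_shift = probs[~after_shift].sum()  # about 0.7, reaches top_p
print(before_shift.tolist())  # [False, True, True, True]
print(after_shift.tolist())  # [False, False, True, True]
```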