
# thomwolf/top-k-top-p.py

Last active May 14, 2024 00:20
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1),
# apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
```

### Hyman25 commented Oct 27, 2020

@mdda
Line 24 does indeed produce a list of indices, and your code helps.

### JiyangZhang commented Nov 19, 2020

How can I return multiple sampled sequences?
My understanding is to run nucleus sampling over the whole sequence multiple times.

Thanks!
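One simple approach is exactly that: repeat the whole generation loop once per desired sequence, sampling independently each time. A minimal sketch (the `sample_sequences` helper and the stand-in `model` are made up for illustration; the top-p filtering inside is inlined from the gist):

```python
import torch
import torch.nn.functional as F


def sample_sequences(model, input_ids, num_sequences=3, steps=5, top_p=0.9, temperature=1.0):
    """Run nucleus (top-p) sampling independently num_sequences times."""
    sequences = []
    for _ in range(num_sequences):
        ids = input_ids.clone()
        for _ in range(steps):
            # Last-token logits of the first batch item, with temperature applied
            logits = model(ids)[0, -1, :] / temperature
            # Nucleus (top-p) filtering, inlined for self-containment
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            logits[sorted_indices[sorted_indices_to_remove]] = -float('inf')
            # Sample one token and append it to the running sequence
            next_token = torch.multinomial(F.softmax(logits, dim=-1), 1)
            ids = torch.cat([ids, next_token.unsqueeze(0)], dim=-1)
        sequences.append(ids)
    return sequences
```

Because each run samples independently, the returned sequences will generally differ; batching the loop (see the discussion below about batch sizes greater than one) would be faster but is not required for correctness.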

### tapdiego-amzn commented Aug 21, 2021

Thank you for this code. Is it distributed under some Open Source license or are there otherwise any limitations on its use?

### kushalj001 commented May 1, 2022

```python
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```

What is the significance of these lines? Cannot get my head around them.
Thanks

### umiswing commented Jul 20, 2022

@thomwolf
Hello! I'm trying to modify the code to support a batch size greater than one. I'm having trouble making the top_p branch support 2D input: I couldn't find an appropriate PyTorch API to index the 2D tensor (lines 26–27 of the code) and implemented it with a for loop, which is too slow. Could you offer some suggestions about the implementation?
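One loop-free option is to scatter the sorted-order mask back to vocabulary order and use `masked_fill`. A sketch of a batched variant (`top_k_top_p_filtering_batched` is a made-up name, and this is only lightly tested on toy inputs):

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    """Batched top-k / nucleus filtering. logits: (batch_size, vocab_size)."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove tokens below the k-th largest logit in each row
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(indices_to_remove, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mask tokens past the nucleus, shifted right to keep the first token above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the sorted-order mask back to the original vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
```

The `scatter` call works because `sorted_indices` is a permutation of each row, so every destination position is written exactly once.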

### nicofirst1 commented Aug 8, 2022

@umiswing I'm also looking for a batched version of this, did you find anything?

### umiswing commented Aug 8, 2022

> @umiswing I'm also looking for a batched version of this, did you find anything?

@nicofirst1 I modified it to a batched version, but I haven't tested it much. I hope it helps.

### not-hermione commented Feb 16, 2023 • edited

```python
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```

What is the significance of these lines? Cannot get my head around them. Thanks

is it to keep at least one value?

### LeeSinLiang commented Apr 24, 2023 • edited

@not-hermione
Consider this example:

```python
x = torch.arange(5, 0, -1)
cumulative_probs = torch.cumsum(x, dim=0)        # tensor([ 5,  9, 12, 14, 15])
sorted_indices_to_remove = cumulative_probs > 13 # tensor([False, False, False,  True,  True])
```

We want to create a boolean mask called `sorted_indices_to_remove` to identify which indices in `cumulative_probs` need to be removed. Specifically, we want to remove indices where the corresponding value in `cumulative_probs` is greater than 13.

Notice the index corresponding to value 14 (the first cumulative value above the threshold) is also marked as `True` in `sorted_indices_to_remove`, but we don't want to remove it: nucleus sampling keeps the first token that crosses the threshold.

To address this issue, we use the following two lines of code:

```python
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```

These two lines shift the values in `sorted_indices_to_remove` to the right by 1 along the last dimension and then set the first value along the last dimension to `False`.
This ensures the index corresponding to value 14 in `cumulative_probs` is not marked as `True` in `sorted_indices_to_remove`.
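Running the shift on the example above makes this concrete (a quick self-contained check):

```python
import torch

x = torch.arange(5, 0, -1)
cumulative_probs = torch.cumsum(x, dim=0)        # tensor([ 5,  9, 12, 14, 15])
sorted_indices_to_remove = cumulative_probs > 13 # tensor([False, False, False,  True,  True])

# Shift right by one and always keep the first position
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
print(sorted_indices_to_remove)  # tensor([False, False, False, False,  True])
```

After the shift, only the last index (cumulative value 15) is removed, while the index for 14, the first value crossing the threshold, is kept.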

### not-hermione commented Apr 24, 2023 • edited

@LeeSinLiang
does anyone have this error? `RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype`?
```python
indices_to_remove = torch.zeros_like(logits, dtype=sorted_indices_to_remove.dtype).scatter_(
```
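That error usually means the destination tensor (`self`) and `src` passed to `scatter_` have different dtypes. One way to avoid it is to build the destination as a bool tensor, matching the bool mask, and only then apply it to the float logits. A sketch with made-up shapes:

```python
import torch

logits = torch.randn(2, 5)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_indices_to_remove = torch.cumsum(torch.softmax(sorted_logits, -1), -1) > 0.9
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False

# self, index and src must agree on dtype for scatter_;
# a bool destination matches the bool mask
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
    -1, sorted_indices, sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float('inf'))
```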