
# thomwolf/top-k-top-p.py

Last active Sep 16, 2021
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        # (torch.topk returns (values, indices); [0] selects the values)
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (`model` and `input` are pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1),
# apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
```
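To make the top-k step concrete, here is a small sketch on made-up logits for a 5-token vocabulary. `torch.topk` returns a (values, indices) pair with values in descending order, so the k-th largest logit is the last of the values; keeping it as a 1-element dimension lets it broadcast against `logits`. Note that a logit tied with the k-th largest is kept, so ties can leave more than k tokens.

```python
import torch

# Made-up logits over a 5-token vocabulary (illustration only).
logits = torch.tensor([2.0, 0.5, 3.0, -1.0, 1.0])
top_k = 2

# k-th largest logit, shape [1] so it broadcasts against logits.
kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
indices_to_remove = logits < kth_largest  # boolean mask over the vocabulary

filtered = logits.clone()
filtered[indices_to_remove] = -float("Inf")
print(filtered)  # only the logits 3.0 and 2.0 survive
```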

### mdda commented May 5, 2019

Line 24: `indices_to_remove = sorted_indices[sorted_indices_to_remove]` does not seem to do what's intended, since the masking operation on the RHS seems to produce a list of indices from `sorted_indices` (but the shape is different from the `logits` that got sorted). I had to go with something like:

```python
indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
    dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
```

This is on PyTorch 1.1 (and 1.0.1, which I was using before I thought I must be going crazy). HTH (please let me know if the above is also wrong...)
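A minimal sketch of the scatter idea extended to a batch, assuming logits of shape `(batch, vocab)` and a recent PyTorch where comparisons return `bool` tensors (so the mask uses `dtype=torch.bool` rather than `torch.uint8`). `top_p_filtering_batched` is a hypothetical helper name, not part of the gist. The mask is computed in sorted order and then scattered back so that position `sorted_indices[b, j]` receives the decision made for sorted slot `j`.

```python
import torch
import torch.nn.functional as F


def top_p_filtering_batched(logits, top_p=0.9, filter_value=-float("Inf")):
    """Hypothetical helper: row-wise nucleus filtering for (batch, vocab) logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark sorted positions past the threshold, then shift right so the
    # first token that crosses top_p is still kept.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Undo the sort: scatter the sorted mask back to vocabulary order.
    to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(
        -1, sorted_indices, sorted_to_remove
    )
    return logits.masked_fill(to_remove, filter_value)


logits = torch.tensor([[0.1, -1.0, 2.0, 1.0],   # unsorted row
                       [0.0, 0.0, 0.0, 0.0]])   # uniform row
filtered = top_p_filtering_batched(logits, top_p=0.9)
print(filtered)
```

Because the mask lives in the same `(batch, vocab)` shape as the logits, `masked_fill` never confuses vocabulary ids with batch positions.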

### 8enmann commented May 15, 2019

I tried it out here and was unimpressed with the sample quality compared to `top_k=40`, `temp=.8`. Has anyone had better luck?

### thomwolf commented May 16, 2019

@mdda, I don't think I understand your issue. Why would the masking operation on the RHS produce a list of indices and not a LongTensor? `torch.sort` produces a tuple (Tensor, LongTensor), and `sorted_indices_to_remove` is a ByteTensor. When we index the LongTensor with a ByteTensor, we get another LongTensor with only the masked elements kept (so not the same size indeed, which is intended). We can then set the masked elements to -inf in the last indexing operation.

Your solution does the same thing, using a ByteTensor to mask instead of a LongTensor to index in the last operation, so the results are identical (I tested it to compare the outputs). So I'm curious: what error was raised in your tests?

@8enmann it's been working slightly better than top-40 in my tests (dialog generation), but I must say the variance of my personal evaluation is quite high. I usually use a slightly lower temperature, `0.7`, and a `top_p` of `0.9`.
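The equivalence claimed above can be checked with a short sketch on random 1-D logits (the gist's `assert logits.dim() == 1` case). It uses `bool` masks, which is what recent PyTorch versions return from comparisons in place of ByteTensors; the vocabulary size of 50 and the seed are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(50)  # 1-D logits, as the gist requires
top_p = 0.9

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_to_remove = cumulative_probs > top_p
sorted_to_remove[1:] = sorted_to_remove[:-1].clone()
sorted_to_remove[0] = False

# Gist approach: mask the LongTensor of sorted positions, giving a shorter
# 1-D tensor of vocabulary ids, then index with it.
ids_to_remove = sorted_indices[sorted_to_remove]
a = logits.clone()
a[ids_to_remove] = -float("Inf")

# Mask approach: scatter the sorted mask back to vocabulary order.
mask = torch.zeros_like(logits, dtype=torch.bool).scatter(0, sorted_indices, sorted_to_remove)
b = logits.masked_fill(mask, -float("Inf"))

print(torch.equal(a, b))  # True
```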

### 8enmann commented May 16, 2019

@thomwolf the paper suggested temperature 1.0, so that's what I'd been using (with `top_p=.9`). Reducing the temperature gives much better results!

I was getting an error with the original code; @mdda's edit fixed it for me. Stack trace below. Note: batch size was 1, not sure if that matters.

@yaroslavvb and I were talking about an automated way to tune these hyperparameters and thought about LAMBADA, or even the likelihood of selecting the "gold" word as they mention in the paper. Haven't tried it.

```
Traceback (most recent call last):
  File "generate.py", line 131, in
    main()
  File "generate.py", line 64, in main
    softmax = hidden_to_softmax(model, pred_hid[-1], top_k=args.top_k, temperature=args.temperature, top_p=args.top_p)
  File "generate.py", line 90, in hidden_to_softmax
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
  File "generate.py", line 127, in top_k_top_p_filtering
    logits[indices_to_remove] = filter_value
IndexError: index 36955 is out of bounds for dimension 0 with size 1
```

### thomwolf commented May 16, 2019

@8enmann Interesting, thanks! I would like to get to the root of this error, if that's OK with you. What is the shape of your input logits tensor? Mine is `torch.Size()` in my current testing setup.

[update] Oh yes, that's the problem: I'm filtering the logits obtained with `logits = logits[0, -1, :] / temperature`, not what I showed here. Let's fix the gist for batch size 1 for now, then (that's the main use case anyway).
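The traceback above can be reproduced in miniature. This sketch assumes made-up logits of shape `[1, V]` with `V = 4` and a hand-picked removal mask: boolean-indexing the 2-D `sorted_indices` yields vocabulary ids, and using those ids on the 2-D logits indexes dimension 0 (size 1), hence the `IndexError`. Selecting the single row first gives the 1-D shape the gist expects.

```python
import torch

logits_2d = torch.tensor([[0.3, 2.0, -1.0, 0.7]])  # shape [1, V]
sorted_logits, sorted_indices = torch.sort(logits_2d, descending=True)

# Hand-picked mask: pretend the last two sorted positions are filtered out.
sorted_to_remove = torch.tensor([[False, False, True, True]])
ids = sorted_indices[sorted_to_remove]  # 1-D tensor of vocabulary ids

scratch = logits_2d.clone()  # work on a copy so the failed write cannot touch logits_2d
try:
    scratch[ids] = -float("Inf")  # ids index dim 0, which has size 1
    failed = False
except IndexError as err:
    failed = True
    print(err)

# Fix: take the single row so the tensor is 1-D, as in logits[0, -1, :].
logits_1d = logits_2d[0].clone()
logits_1d[ids] = -float("Inf")
print(logits_1d)
```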

### 8enmann commented May 16, 2019

`torch.Size([1, 50257])`

### 8enmann commented May 16, 2019

Here's an example based on huggingface run_gpt2: https://github.com/JasonBenn/duet/blob/master/generate.py

### mataney commented Sep 25, 2019

Hi, I added a PR that extends this to `num_samples > 1`: huggingface/transformers#1333 :)
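This sketch does not reproduce the PR itself; it only shows the PyTorch call that drawing more than one sample relies on. `torch.multinomial` accepts `num_samples > 1`: with `replacement=True` tokens may repeat, while with `replacement=False` the draws are distinct and there must be enough non-zero entries. The probability vector stands in for the output of filtering plus softmax and is made up.

```python
import torch

torch.manual_seed(0)
# Stand-in for F.softmax(filtered_logits): token 3 was filtered out.
probabilities = torch.tensor([0.5, 0.3, 0.2, 0.0])

# Several draws from the same distribution; repeats are allowed.
samples = torch.multinomial(probabilities, num_samples=5, replacement=True)

# Distinct draws; num_samples must not exceed the non-zero entries.
distinct = torch.multinomial(probabilities, num_samples=2, replacement=False)
print(samples, distinct)
```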

### Ace-ezer commented Nov 20, 2019

I am getting this error here; also, can you please elaborate on what that `input` is?

```
logits = model.forward(input)
Traceback (most recent call last):
  File "", line 1, in
    logits = model.forward(input)
TypeError: forward() missing 1 required positional argument: 'mc_token_ids'
```

### TheEdoardo93 commented Dec 18, 2019

You can refer to `run_generation.py` in Transformers for full working source code on a real model, i.e. OpenAI GPT-2.

### aj7tesh commented Feb 14, 2020

> sample from the smallest set whose cumulative probability mass exceeds p for next words

What exactly does this mean? Let's say I set p = 0.9: will it keep only those next tokens that have a probability > 0.9, or what?
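A worked example may help here, using a made-up distribution that is already sorted. Following the gist's logic, the threshold applies to the running total of the most likely tokens, not to individual probabilities: with p = 0.9 the tokens are added in order of probability until their combined mass first exceeds 0.9, so a token with probability 0.15 can be kept even though it is far below 0.9 on its own.

```python
import torch

# Made-up next-token probabilities, sorted in descending order.
probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
top_p = 0.9

cumulative = torch.cumsum(probs, dim=-1)    # 0.50, 0.80, 0.95, 1.00
remove = cumulative > top_p                 # F, F, T, T
remove[1:] = remove[:-1].clone()            # shift right: F, F, F, T
remove[0] = False                           # always keep the most likely token
nucleus = torch.nonzero(~remove).flatten()  # tokens 0, 1 and 2 are kept
print(nucleus)
```

The kept tokens cover 0.95 of the mass; sampling then happens among them after renormalization.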

### chungyilinxrspace commented Feb 20, 2020

@thomwolf Hi, I have recently been learning about temperature sampling and nucleus sampling, and I read the paper "The Curious Case of Neural Text Degeneration". They rescale the original distribution into a new distribution, but the top_k_top_p_filtering function sets the logit score to zero without changing the probability distribution. Is "changing the probability distribution" necessary for top-p sampling? Thank you ~

### tbazin commented Mar 3, 2020

@chungyilinxrspace Hi! TL;DR: the filtering function provided operates on the logits, not on the probabilities. After filtering, the logits are converted to class probabilities by the call to `F.softmax`, which ensures both that the filtered classes have zero probability (since their logit value is `float("-inf")`) and that the remaining probabilities form a proper, rescaled probability distribution. Hence the probability distribution is indeed "changed".
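A quick numerical check of this point, with made-up logits and a hand-picked set of kept tokens: setting the filtered logits to `-inf` and applying softmax gives the same distribution as zeroing the filtered probabilities and dividing by the remaining mass, which is the rescaling described in the paper.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
keep = torch.tensor([True, True, True, False])  # e.g. the nucleus chosen by top-p

# Route 1 (the gist): filter the logits to -inf, then softmax.
p1 = F.softmax(logits.masked_fill(~keep, -float("Inf")), dim=-1)

# Route 2 (explicit rescaling): zero the filtered probabilities, renormalize.
probs = F.softmax(logits, dim=-1)
p2 = probs * keep
p2 = p2 / p2.sum()

print(torch.allclose(p1, p2))  # True
```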

### nilinykh commented Jun 3, 2020

Hello all! First, thank you for a very nice piece of code. I have a more general question about nucleus sampling itself; maybe someone will be willing to clarify a few things for me. How do we choose k and p? As far as I understand, every time we generate text, the output will be different, even if k and p stay the same. In other words, one cannot get a stable generated output (unlike with greedy or beam search). Is there a good approximation of what values these parameters should take? Or should they be based solely on empirical observations for a particular problem? If the latter is the case, can anyone point me towards basic ideas on how changing k and/or p affects the generated output in general?
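There is no universal setting; values mentioned in this thread include `top_p=0.9` with temperature `0.7`, and `top_k=40` with temperature `0.8`. The sketch below, on a made-up 6-token distribution, shows two things: how the nucleus grows with p and shrinks at a lower temperature (fewer candidates usually means more conservative text), and how seeding a `torch.Generator` makes a sampling run repeatable. `nucleus_size` is a hypothetical helper using the gist's shift rule.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([3.0, 2.0, 1.0, 0.5, 0.0, -1.0])  # made-up logits


def nucleus_size(logits, top_p, temperature=1.0):
    """Hypothetical helper: number of tokens kept by top-p filtering."""
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    remove = torch.cumsum(sorted_probs, dim=-1) > top_p
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    return int((~remove).sum())


# (size at temperature 1.0, size at temperature 0.7) for several p values.
sizes = {p: (nucleus_size(logits, p), nucleus_size(logits, p, 0.7)) for p in (0.5, 0.9, 0.99)}
print(sizes)

# Sampling is random, but two generators with the same seed give the same draws.
probs = F.softmax(logits, dim=-1)
g1 = torch.Generator().manual_seed(1234)
g2 = torch.Generator().manual_seed(1234)
draw1 = torch.multinomial(probs, 10, replacement=True, generator=g1)
draw2 = torch.multinomial(probs, 10, replacement=True, generator=g2)
print(torch.equal(draw1, draw2))  # True
```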

### Hyman25 commented Oct 27, 2020

@mdda Line 24 does indeed produce a list of indices, and your code helps.

### JiyangZhang commented Nov 19, 2020

How can I return multiple sampled sequences? My understanding is that you run nucleus sampling over a whole sequence multiple times. Thanks!
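One way to get several sequences, sketched under loud assumptions: instead of looping over whole sequences one at a time, copy the prompt into a batch and sample every row independently at each step. The "model" here is a hypothetical stand-in (a fixed random bigram table mapping the previous token to next-token logits); a real language model call would replace it. `top_p_filter_rows` is likewise a hypothetical helper using the scatter approach from earlier in the thread.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, num_sequences, max_len = 8, 3, 5

# Hypothetical stand-in for a language model: previous token -> next-token logits.
bigram_logits = torch.randn(vocab_size, vocab_size)


def top_p_filter_rows(logits, top_p):
    """Hypothetical helper: row-wise nucleus filtering for (batch, vocab) logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_remove = cumulative > top_p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, -float("Inf"))


# Every sequence starts from the same prompt token (0) and is sampled in parallel.
sequences = torch.zeros(num_sequences, 1, dtype=torch.long)
for _ in range(max_len):
    next_logits = bigram_logits[sequences[:, -1]] / 0.7  # temperature 0.7
    probs = F.softmax(top_p_filter_rows(next_logits, top_p=0.9), dim=-1)
    next_tokens = torch.multinomial(probs, 1)  # one draw per row, shape (batch, 1)
    sequences = torch.cat([sequences, next_tokens], dim=1)

print(sequences)
```

Running the single-sequence loop several times gives the same kind of result; batching just does the draws together.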

### BenjaminWegener commented May 28, 2021

Here is a working example: https://github.com/BenjaminWegener/gpt2_torch_nucleus

### tapdiego-amzn commented Aug 21, 2021

 Thank you for this code. Is it distributed under some Open Source license or are there otherwise any limitations on its use?