Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling

import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
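
A quick, self-contained way to sanity-check the filter without a real model is shown below; the toy logits, the top_p=0.8 threshold, and the printouts are made up for illustration only.

# Sanity check on a tiny made-up "vocabulary" of 5 tokens; no model needed.
import torch
import torch.nn.functional as F

toy_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
filtered = top_k_top_p_filtering(toy_logits.clone(), top_p=0.8)   # keep the smallest set covering >= 0.8 probability
probs = F.softmax(filtered, dim=-1)
print(probs)                          # tokens outside the nucleus get probability 0
print(torch.multinomial(probs, 1))    # sample one token id from the remaining nucleus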
@mdda commented May 5, 2019

Line 24: indices_to_remove = sorted_indices[sorted_indices_to_remove] does not seem to do what's intended, since the masking operation on the RHS seems to produce a list of indices drawn from sorted_indices (but its shape is different from that of the logits that were sorted)

I had to go with something like:

    indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove )

This is on PyTorch 1.1 (and 1.0.1, which I was using before, when I thought I must be going crazy)

HTH (please let me know if the above is also wrong...)
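
In full, a batch-capable variant built on that scatter_ idea might look like the sketch below. This is an untested adaptation for illustration, not the gist's code; the torch.bool dtype assumes a recent PyTorch (on 1.1 the mask would be a uint8 ByteTensor).

    # Sketch of a batch-capable variant using the scatter_ trick; adaptation for illustration only.
    import torch
    import torch.nn.functional as F

    def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        # logits: shape (batch_size, vocabulary_size)
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold,
            # keeping the first token that crosses it
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the sorted mask back to the original (unsorted) token order
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits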

@8enmann commented May 15, 2019

I tried it out here and was unimpressed with the sample quality compared to top_k=40, temp=0.8.
Has anyone had better luck?

@thomwolf (Owner) commented May 16, 2019

@mdda, I don't think I understand your issue.
Why would the masking operation on the RHS produce a list of indices and not a LongTensor?
torch.sort returns a tuple (Tensor, LongTensor), and sorted_indices_to_remove is a ByteTensor. When we index the LongTensor with a ByteTensor, we get another LongTensor containing only the masked elements (so indeed not the same size, which is intended).
We can then set the masked elements to -inf in the last indexing operation.
Your solution does the same thing, using a ByteTensor to mask instead of a LongTensor to index in the last operation, so the results are identical (I tested both to compare the outputs).
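
To make the equivalence concrete, here is a toy example (mine, not from the gist; on recent PyTorch the mask is a BoolTensor rather than a ByteTensor):

    # Toy example: both ways of building indices_to_remove give the same filtered logits.
    import torch

    sorted_indices = torch.tensor([3, 0, 2, 1])                           # LongTensor from torch.sort
    sorted_indices_to_remove = torch.tensor([False, False, True, True])   # mask over the sorted order

    # 1) Mask-indexing the LongTensor: yields the token ids to drop, here tensor([2, 1])
    idx = sorted_indices[sorted_indices_to_remove]
    # 2) Scattering the mask back to vocabulary order: True at positions 2 and 1
    mask = torch.zeros(4, dtype=torch.bool).scatter_(0, sorted_indices, sorted_indices_to_remove)

    logits_a = torch.arange(4.0); logits_a[idx] = -float('Inf')
    logits_b = torch.arange(4.0); logits_b[mask] = -float('Inf')
    assert torch.equal(logits_a, logits_b)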

So I'm curious, what error was raised in your tests?

@8enmann it's been working slightly better than top-40 in my tests (dialog generation), but I must say the variance of my personal evaluation is quite high. I usually use a slightly lower temperature, 0.7, and a top_p of 0.9.

@8enmann commented May 16, 2019

@thomwolf the paper suggested temperature 1.0, so that's what I'd been using (and top_p=.9). Reducing the temperature is giving much better results!

I was getting an error using the original code; @mdda's edit fixed it for me. Stack trace below. Note: batch size was 1, not sure if that matters.

@yaroslavvb and I were talking about an automated way to tune these hyperparameters and thought about LAMBADA, or even the likelihood of selecting the "gold" word as they mention in the paper (a rough sketch of that idea follows the traceback below). Haven't tried it yet.

Traceback (most recent call last):
  File "generate.py", line 131, in <module>
    main()
  File "generate.py", line 64, in main
    softmax = hidden_to_softmax(model, pred_hid[-1], top_k=args.top_k, temperature=args.temperature, top_p=args.top_p)
  File "generate.py", line 90, in hidden_to_softmax
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
  File "generate.py", line 127, in top_k_top_p_filtering
    logits[indices_to_remove] = filter_value
IndexError: index 36955 is out of bounds for dimension 0 with size 1
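
Regarding the automated-tuning idea above, one rough interpretation of "likelihood of the gold word" is sketched below; the toy logits, the gold_token_id, and the hyperparameter values are hypothetical placeholders, not a tested evaluation.

    # Rough sketch: score a sampling setting by the probability that the filtered
    # distribution assigns to the reference ("gold") next token. Toy values only.
    import torch
    import torch.nn.functional as F

    toy_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
    gold_token_id = 1                                              # hypothetical reference token
    filtered = top_k_top_p_filtering(toy_logits / 0.7, top_p=0.9)  # temperature 0.7, top_p 0.9
    gold_likelihood = F.softmax(filtered, dim=-1)[gold_token_id].item()
    print(gold_likelihood)   # higher is better; average this over a validation set to compare settings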
@thomwolf (Owner) commented May 16, 2019

@8enmann Interesting, thanks! I would like to get to the root of this error if you're OK with it. What is the shape of your input logits tensor?
Mine is torch.Size([50262]) in my current testing setup.
[Update] Oh yes, that's the problem: I'm filtering the logits with logits = logits[0, -1, :] / temperature, not what I showed here.
Let's fix the gist for batch_size 1 for now, then (that's the main use case anyway).

@mataney commented Sep 25, 2019

Hi, I added a PR about extending this to num_samples > 1.
huggingface/transformers#1333
:)
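
In the meantime, for anyone who just wants several candidates from a single filtered distribution (not the fully batched version in the PR), a minimal sketch, reusing filtered_logits from the usage example above; num_samples=3 is illustrative:

    # Draw several candidate next tokens from one already-filtered distribution.
    # replacement=True avoids errors when the nucleus contains fewer tokens than num_samples.
    probabilities = F.softmax(filtered_logits, dim=-1)
    next_tokens = torch.multinomial(probabilities, num_samples=3, replacement=True)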
