import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask over sorted positions back to the original vocabulary order
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
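To generate more than one token, the same step can be run in a loop, feeding each sampled token back into the context. A minimal sketch (not part of the original gist; it assumes model returns logits of shape (batch, sequence, vocab) and input is a (1, seq_len) tensor of token ids, as above):

# Illustrative autoregressive loop; `model` and `input` are the same objects as above
generated = input
for _ in range(20):  # sample 20 tokens
    logits = model(generated)
    next_token_logits = logits[0, -1, :] / temperature
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered_logits, dim=-1)
    next_token = torch.multinomial(probabilities, 1)
    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)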
I tried it out here and was unimpressed with sample quality compared to top_k=40, temp=0.8.
@mdda, I don't think I understand your issue. So I'm curious, what error was raised in your tests? @8enmann it's been working slightly better than top-40 in my tests (dialog generation), but the variance of my personal evaluation is quite high, I must say. I usually use a temperature a bit lower:
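For illustration only, with a made-up value rather than the one that originally followed here:

temperature = 0.7  # somewhat sharper than the paper's suggested 1.0
next_token_logits = logits / temperature  # 1-D vocabulary logits, as in the gist
filtered_logits = top_k_top_p_filtering(next_token_logits, top_p=0.9)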
@thomwolf the paper suggested temperature 1.0, so that's what I'd been using. I was getting an error using the original code; @mdda's edit fixed it for me (stack trace below). Note: batch size was 1, not sure if that matters. @yaroslavvb and I were talking about an automated way to tune these hyperparameters and thought about LAMBADA, or even the likelihood of selecting the "gold" word as they mention in the paper. Haven't tried it.
@8enmann Interesting, thanks! I would like to get to the root of this error if you're ok. What is the shape of your input logits tensor? Mine is torch.Size([50262]) in my current testing setup.
torch.Size([1, 50257])
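That 2-D shape would trip the assert logits.dim() == 1 at the top of the function, which expects a 1-D vector of vocabulary logits. One plausible fix (a sketch, not taken from the thread) is to drop the batch dimension first:

logits = logits.squeeze(0)  # (1, 50257) -> (50257,)
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)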
Here's an example based on huggingface run_gpt2: https://github.com/JasonBenn/duet/blob/master/generate.py
Hi, added a PR about extending to num_samples > 1.
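Sketch of the general idea only, not the PR's actual diff: a batch-aware variant can scatter the sorted-order mask back to the original vocabulary order instead of fancy-indexing, after which torch.multinomial samples one token per row:

def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: (batch_size, vocab_size)
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # map the mask over sorted positions back to the original vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits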
I am getting this error here; also, can you please elaborate on that "input"?

logits = model.forward(input)
  File "<stdin>", line 1, in <module>
TypeError: forward() missing 1 required positional argument: 'mc_token_ids'
You can refer to run_generation.py in Transformers for full source code working with a real model, i.e. OpenAI GPT-2.
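That mc_token_ids argument belongs to the double-heads model classes; for plain generation you want an LM-head model. A sketch, assuming the pytorch-transformers API of that era:

import torch
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

input = torch.tensor([tokenizer.encode("The quick brown fox")])  # (1, seq_len) token ids
with torch.no_grad():
    logits = model(input)[0]  # (1, seq_len, vocab_size)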
Sample from the smallest set whose cumulative probability mass exceeds p for next words.
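As a toy illustration with made-up numbers: for next-token probabilities 0.5, 0.3, 0.15, 0.05 and p = 0.9, the cumulative mass first exceeds 0.9 at the third token (0.5 + 0.3 + 0.15 = 0.95), so the nucleus keeps the first three tokens and filters the fourth:

logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
filtered = top_k_top_p_filtering(logits, top_p=0.9)
print(F.softmax(filtered, dim=-1))  # ~[0.526, 0.316, 0.158, 0.000]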
@thomwolf In the top_k_top_p_filtering function, it sets the logit score to zero but doesn't change the probability distribution.
Hi! After filtering the logits, they are converted to class probabilities via the call to F.softmax, so any token whose logit was set to -Inf gets zero probability mass.
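A quick numerical check of that:

probs = F.softmax(torch.tensor([2.0, 1.0, -float('Inf')]), dim=-1)
# tensor([0.7311, 0.2689, 0.0000]) - the filtered token gets zero probability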
Hello all! I have a more general question about nucleus sampling itself; maybe someone will be willing to clarify several things for me.
@mdda
How can I return multiple sampling sequences? Thanks!
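Not an answer from the thread, just a sketch: with this 1-D setup, torch.multinomial can draw several candidate next tokens in one call; replacement=True avoids an error when fewer tokens survive the filter than samples requested.

next_tokens = torch.multinomial(probabilities, num_samples=5, replacement=True)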
Line 24:
indices_to_remove = sorted_indices[sorted_indices_to_remove]
does not seem to do what's intended, since the masking operation on the RHS seems to produce a list of indices from sorted_indices (but the shape is different from the logits that got sorted). I had to go with something like:
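(A sketch of one such workaround, not necessarily the exact code from this comment: build the mask in the original vocabulary order via scatter instead of fancy-indexing.)

# scatter the sorted-order mask back to the original vocabulary order
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value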
This is on PyTorch 1.1 (and 1.0.1, which I was using before I thought I must be going crazy).
HTH (please let me know if the above is also wrong...)