import numpy as np
import torch

# USAGE EXAMPLE (illustrative pseudocode; `llm` and `sample` are placeholders for your own model call and sampler)
# logits = llm(...)                             # Get raw logits from an LLM
# logits = tail_free_sampling(logits, z=0.95)   # Cut off logits in the tail
# token = sample(logits, temperature=1.0)       # Do your usual sampling with temp/top-p

def tail_free_sampling(logits: torch.Tensor, z: float = 0.95, mask_value: float = -float('inf')) -> torch.Tensor:
    """
    Tail-free sampling: masks out the low-probability "tail" of the next-token distribution.

    See https://www.trentonbricken.com/Tail-Free-Sampling/
    Code copied from https://github.com/finetunej/transformers/blob/c83109932f4592b871ec4c60326df3b4173b021a/src/transformers/generation_logits_process.py#L243-L284

    :param logits: 1D tensor of raw logits for the next token.
    :param z: Hyperparameter for tail-free sampling.
    :param mask_value: Tokens excluded from sampling have their logit set to this value.
    :return: Masked logits.
    """
    assert len(logits.shape) == 1, str(logits.shape)

    # numpy sort is faster than PyTorch (5 ms vs 7 ms)
    logits_np = logits.detach().cpu().numpy()
    # argsort is ascending, so flip the result to get descending order
    sorted_indices_np = np.argsort(logits_np, kind='quicksort')
    sorted_indices_np = np.ascontiguousarray(np.flip(sorted_indices_np))
    sorted_logits_np = logits_np[sorted_indices_np]

    sorted_indices = torch.tensor(sorted_indices_np, device=logits.device)
    sorted_logits = torch.tensor(sorted_logits_np, device=logits.device)
    d = sorted_logits.softmax(dim=-1)
    # First discrete derivative of the sorted probability curve...
    d = d[1:] - d[:-1]
    # ...and second derivative; its magnitude marks where the "tail" begins
    d = d[1:] - d[:-1]
    d = d.abs()
    # Normalize so the derivative weights sum to 1
    d = d / d.sum(dim=-1).item()

    cumulative_probs = d.cumsum(dim=-1)
    # Remove tokens whose cumulative derivative weight exceeds the threshold z
    sorted_indices_to_remove = torch.zeros(sorted_indices.shape, dtype=torch.bool, device=logits.device)
    sorted_indices_to_remove[:-2] = (cumulative_probs > z)[:]

    # Always keep the first (most probable) token
    sorted_indices_to_remove[0] = 0
    # Always remove the last two tokens -- they should have negligible probability anyway
    # (the double diff above is two elements shorter than the vocabulary)
    sorted_indices_to_remove[-1] = 1
    sorted_indices_to_remove[-2] = 1

    # Scatter the sorted mask back to the original token indexing
    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)

    return logits.masked_fill(indices_to_remove, mask_value)
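

# The `sample` helper referenced in the usage example above is not defined in this
# gist. Below is a minimal, hypothetical sketch (name, signature and behavior are
# assumptions, not part of the original code): temperature scaling followed by
# multinomial sampling. Tokens masked to -inf get probability 0 and are never picked.
def sample(logits: torch.Tensor, temperature: float = 1.0) -> int:
    probs = (logits / temperature).softmax(dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


# Quick smoke test with random logits standing in for real model output.
if __name__ == '__main__':
    fake_logits = torch.randn(32)
    masked = tail_free_sampling(fake_logits, z=0.95)
    print('sampled token id:', sample(masked, temperature=1.0))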