Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Last active June 30, 2022 15:11
Show Gist options
  • Save kastnerkyle/81939e0843f2538e26af9cb2b5bf4ed2 to your computer and use it in GitHub Desktop.
Save kastnerkyle/81939e0843f2538e26af9cb2b5bf4ed2 to your computer and use it in GitHub Desktop.
My own take on a plug-and-play setup for typical sampling from Meister et al., "Typical Decoding for Natural Language Generation". For now, it also adds top-k filtering by typicality.
def typical_top_k_filtering(logits, top_k=0, top_p=0.0, temperature=1.0, min_tokens_to_keep=1, filter_value=-1E12):
    """Filter a distribution of logits using typicality, with optional top-k and mass cuts.

    Implements the typical-set truncation of Meister et al., "Typical Decoding
    for Natural Language Generation" (https://arxiv.org/abs/2202.00666).
    Tokens are ranked by ``|-log p(x) - H(p)|``, the distance of their
    information content from the entropy of the distribution (0 = most
    typical). The most typical tokens are kept until their cumulative
    probability passes ``top_p``. "Top-k" here means the k *most typical*
    tokens, not the k most probable.

    Based on the Hugging Face / cimeister typical-sampling implementation
    (https://github.com/cimeister/typical-sampling), with top-k added.

    Args:
        logits: tensor of shape (..., vocabulary size). It is not modified in place.
        top_k: if > 0, keep only the ``top_k`` most typical tokens. Tokens tied
            with the k-th typicality score are also kept. It is clamped to
            ``vocab_size - 1``.
        top_p: if > 0.0, the typical probability mass to keep. Tokens are kept
            in typicality order up to and including the first one at which the
            cumulative probability exceeds ``top_p``. If 0.0, a mass of 1.0 is used.
        temperature: logits are divided by this value before filtering. Must be > 0.
        min_tokens_to_keep: the mass cut always keeps at least this many of
            the most typical tokens. It does not undo the top-k cut.
        filter_value: value written into the logits of removed tokens.

    Returns:
        A new tensor with the same shape as ``logits``. Removed tokens hold
        ``filter_value``; kept tokens hold ``logits / temperature``.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0, got {}".format(temperature))
    # Dividing also allocates a fresh tensor, so the caller's `logits` is never mutated.
    scores = logits / temperature
    vocab_size = scores.size(-1)
    mass = top_p if top_p > 0.0 else 1.0

    # Entropy H(p). nansum: a token with a -inf logit yields 0 * -inf = nan,
    # which must contribute nothing instead of poisoning the whole row.
    log_p = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(log_p)
    ent = -torch.nansum(p * log_p, dim=-1, keepdim=True)

    # Typicality |I(x) - H| with I(x) = -log p(x) (eq. 5 of the paper).
    # The original reference code wrote this as ((-ent) - log_p).
    shifted_scores = torch.abs(-log_p - ent)
    # Order from most typical (0) to least typical (large value).
    _, sorted_indices = torch.sort(shifted_scores, descending=False, stable=True)

    top_k = min(top_k, vocab_size - 1)  # safety check that top k is not too large
    if top_k > 0:
        # k-th smallest typicality score per row; strictly less typical tokens are removed.
        kth_score = torch.topk(shifted_scores, top_k, dim=-1, largest=False)[0][..., -1:]
        scores = scores.masked_fill(shifted_scores > kth_score, filter_value)

    sorted_scores = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens once the cumulative probability is above the threshold.
    sorted_indices_to_remove = cumulative_probs > mass
    if min_tokens_to_keep > 1:
        # min_tokens_to_keep - 1 because the shift below always keeps position 0 too.
        sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = False
    # Shift right so the first token that crosses the threshold is also kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the removal mask from typicality order back to vocabulary order.
    # This is done one row at a time, with the indices on the CPU: the original
    # author reported CUDA errors from vectorized scatter/gather at this step.
    # Iterating over the flattened rows handles every leading shape, including 1-D.
    shape = scores.shape
    flat_scores = scores.reshape(-1, vocab_size)
    flat_order = sorted_indices.reshape(-1, vocab_size).cpu()
    flat_remove = sorted_indices_to_remove.reshape(-1, vocab_size).cpu()
    for row, (order, remove) in enumerate(zip(flat_order, flat_remove)):
        # Plain assignment (not score * (1 - m) + m * filter_value), so a removed
        # -inf logit becomes filter_value rather than NaN.
        flat_scores[row, order[remove]] = filter_value
    return flat_scores.reshape(shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment