Last active June 30, 2022 15:11
My own take on a plug-and-play setup for typical sampling from Meister et al., "Typical Decoding for Natural Language Generation". Added top-k by typicality for now.
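In short, typical decoding ranks tokens by how close their information content -log p(x) is to the entropy H of the conditional distribution (eq. 5 in the paper), then keeps the most typical tokens up to the requested probability mass. A toy illustration of that ranking, with made-up numbers (not from the paper or the gist):

import torch

p = torch.tensor([0.5, 0.3, 0.15, 0.05])  # toy next-token distribution
log_p = p.log()
ent = -(p * log_p).sum()                   # entropy H of the distribution
typicality = (-log_p - ent).abs()          # |I(x) - H|: lower means more typical
print(typicality.argsort())                # tensor([1, 0, 2, 3]) here

Note that the most probable token is not necessarily the most typical one; that is exactly what separates this ranking from plain top-k on probabilities.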
import torch


def typical_top_k_filtering(logits, top_k=0, top_p=0.0, temperature=1.0, min_tokens_to_keep=1, filter_value=-1E12):
    """ Filter a distribution of logits using typicality, with optional top-k and/or nucleus (top-p) filtering
        Meister et al. https://arxiv.org/abs/2202.00666

        Args:
            logits: logits distribution shape (..., vocabulary size)
            top_k > 0: keep the top k most typical tokens (top-k by typicality)
            top_p > 0.0: keep the most typical tokens whose cumulative probability mass adds up to top_p (nucleus filtering)
            min_tokens_to_keep >= 1: always keep at least this many tokens through the top_p / nucleus filtering
            temperature: softmax temperature applied to the logits before filtering
    """
    # https://arxiv.org/abs/2202.00666
    # based on the Hugging Face impl but added top k
    # https://github.com/cimeister/typical-sampling/commit/0f24c9409dc078ed23982197e8af1439093eedd3#diff-cde731a000ec723e7224c8aed4ffdedc9751f5599fe0a859c5c65d0c5d94891dR249
    # changed some of the scatter logic to looping + indexing due to spooky threaded CUDA errors, need CUDA_NONBLOCKING=1 to fix
    # typical decoding
    scores = logits / temperature  # apply temperature; this also copies, so the caller's logits are left untouched
    mass = top_p if top_p > 0.0 else 1.0
    # calculate the entropy of the conditional distribution
    log_p = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(log_p)
    ent = -(p * log_p).sum(-1, keepdim=True)
    # shift and sort
    # abs(I() - H())
    # I() is -log(p()) from eq 5
    # so overall we see -log(p()) - ent
    # orig code was ((-ent) - log_p)
    shifted_scores = torch.abs(-log_p - ent)
    # most typical (0) to least typical (high abs value)
    _, sorted_indices = torch.sort(shifted_scores, descending=False, stable=True)
    top_k = min(top_k, scores.size(-1) - 1)  # safety check that top k is not too large
    if top_k > 0:
        # keep the k most typical tokens: flip the shifted scores so topk selects the smallest ones
        topkval = torch.topk(torch.max(shifted_scores) - shifted_scores, top_k)[0][..., -1, None]
        indices_to_remove = (torch.max(shifted_scores) - shifted_scores) < topkval
        scores[indices_to_remove] = filter_value
    sorted_scores = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
    # remove tokens once the cumulative probability is above the threshold
    sorted_indices_to_remove = cumulative_probs > mass
    sorted_indices_to_remove = sorted_indices_to_remove.long()
    if min_tokens_to_keep > 1:
        # keep at least min_tokens_to_keep (set to min_tokens_to_keep - 1 because we add the first one below)
        sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = 0
    # shift the indices to the right to also keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # take a detached copy of the sort indices (sidesteps more of the same CUDA weirdness) on the same device
    sorted_indices = sorted_indices.detach().clone()
    shp = scores.shape
    # not great, gather calls here hit CUDA errors, rewrote to a "slow" loop + indexing version
    scores_red = scores.reshape((-1, shp[-1]))
    sorted_indices_red = sorted_indices.reshape((-1, shp[-1]))
    sorted_indices_to_remove_red = sorted_indices_to_remove.reshape((-1, shp[-1]))
    for i in range(scores_red.shape[0]):
        # blend: keep the score where the remove flag is 0, write filter_value where it is 1
        scores_red[i][sorted_indices_red[i]] = scores_red[i][sorted_indices_red[i]] * (1. - sorted_indices_to_remove_red[i]) + sorted_indices_to_remove_red[i] * filter_value
    scores = scores_red.reshape(shp)
    return scores
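For reference, a minimal sketch of how the filter might sit in a sampling step; the shapes and the torch.randn stand-in for real model logits are assumptions, not part of the gist:

# minimal usage sketch: `logits` stands in for a model's last-step output
logits = torch.randn(2, 50257)  # (batch, vocab), e.g. a GPT-2 sized vocabulary
filtered = typical_top_k_filtering(logits, top_k=50, top_p=0.95)
probs = torch.softmax(filtered, dim=-1)               # filtered entries get ~0 probability
next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1) sampled token ids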