@LeeSinLiang
Created September 1, 2024 06:06
Top-k and top-p (nucleus) filtering of logits. Used for microGPT inference.
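import torch
import torch.nn.functional as F


def top_k_top_p_filter(logits, top_k: int = 0, top_p: float = 0.0):
    # Masks `logits` of shape (batch, vocab) in place; 0 disables either filter.
    if top_k > 0:
        # Keep the top_k largest logits per row (ties with the k-th value survive).
        top_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
        logits[logits < top_values[:, [-1]]] = float('-inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark tokens once the cumulative probability exceeds top_p.
        remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept too.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = remove.scatter(1, sorted_indices, remove)
        logits[indices_to_remove] = float('-inf')
    return logits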
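
To make the masking rule easier to check by hand, here is a minimal pure-Python sketch of the same idea for a single row of logits. It is not part of the gist: `surviving_tokens` is a hypothetical helper name, it uses only the standard library instead of PyTorch, and it returns the indices that would stay finite rather than modifying the logits. It assumes a non-empty list of finite logits.

```python
import math


def surviving_tokens(logits, top_k=0, top_p=0.0):
    """Return the indices that top-k then top-p filtering would keep.

    Hypothetical illustration for one row of logits; 0 disables a filter.
    """
    keep = list(range(len(logits)))
    if top_k > 0:
        # Keep every logit at least as large as the k-th largest one.
        k = min(top_k, len(logits))
        threshold = sorted(logits, reverse=True)[k - 1]
        keep = [i for i in keep if logits[i] >= threshold]
    if top_p > 0.0:
        # Softmax over the tokens that survived top-k, in descending order.
        order = sorted(keep, key=lambda i: logits[i], reverse=True)
        peak = logits[order[0]]
        weights = [math.exp(logits[i] - peak) for i in order]
        total = sum(weights)
        kept, cumulative = [], 0.0
        for i, w in zip(order, weights):
            # The token that pushes the cumulative mass past top_p is
            # still kept; everything after it is dropped.
            kept.append(i)
            cumulative += w / total
            if cumulative > top_p:
                break
        keep = kept
    return sorted(keep)


if __name__ == "__main__":
    row = [1.0, 2.0, -1.0, 0.5]
    print(surviving_tokens(row, top_k=3))
    print(surviving_tokens(row, top_p=0.7))
```

The early `break` plays the role of the one-position shift of the mask in the tensor version: the first token whose cumulative probability exceeds `top_p` is retained, so at least one token always survives.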