@lzqlzzq
Created December 4, 2023 05:57
Temperature, top-k, top-p sampler with repetition_penalty
import torch


class Sampler:
    """Temperature, top-k, and top-p (nucleus) sampler with a repetition penalty."""

    def __init__(self,
                 top_k: int,
                 top_p: float,
                 temperature: float,
                 repetition_penalty: float):
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty

    def _temperature(self, probs: torch.Tensor):
        # Rescale the distribution by 1/temperature in log space, then renormalize.
        scaled_log_probs = torch.log(probs) / self.temperature
        t_probs = torch.exp(scaled_log_probs)
        return t_probs / t_probs.sum(dim=-1, keepdim=True)

    def _apply_repetition_penalty(self, probs: torch.Tensor, prev_tokens: torch.Tensor):
        # Scale the probability of previously generated tokens by the penalty
        # (a value < 1 discourages repetition), then renormalize.
        penalty = torch.ones_like(probs).scatter_(-1, prev_tokens, self.repetition_penalty)
        probs = probs * penalty
        return probs / probs.sum(dim=-1, keepdim=True)

    def _top_k(self, probs: torch.Tensor):
        # Keep only the k most probable tokens and renormalize.
        _, k_idxs = torch.topk(probs, self.top_k)
        mask = torch.zeros_like(probs).scatter_(-1, k_idxs, 1.)
        probs = probs * mask
        return probs / probs.sum(dim=-1, keepdim=True)

    def _top_p(self, probs: torch.Tensor):
        # Keep the smallest set of tokens whose cumulative probability stays within
        # top_p, always retaining at least the most probable token so the
        # distribution can never become all zeros; then renormalize.
        sorted_probs, p_idxs = torch.sort(probs, dim=-1, descending=True)
        mask = torch.cumsum(sorted_probs, dim=-1) <= self.top_p
        mask[..., 0] = True
        probs = probs * torch.zeros_like(probs).scatter_(-1, p_idxs, mask.float())
        return probs / probs.sum(dim=-1, keepdim=True)

    def __call__(self,
                 probs: torch.Tensor,
                 prev_tokens: torch.Tensor = None,
                 return_prob: bool = False):
        t_probs = self._temperature(probs)
        r_probs = self._apply_repetition_penalty(t_probs, prev_tokens) if prev_tokens is not None else t_probs
        k_probs = self._top_k(r_probs)
        p_probs = self._top_p(k_probs)
        return p_probs if return_prob else torch.multinomial(p_probs, num_samples=1).item()
if __name__ == '__main__':
    sampler = Sampler(5, .8, 2.1, .9)
    probs = torch.softmax(torch.randn(10), dim=-1)
    dummy_prev_tokens = torch.LongTensor([1, 3, 4])

    print('Original distribution: \n', probs)
    print('Distribution for sampling: \n', sampler(probs, dummy_prev_tokens, return_prob=True))
    print('Sampling 10 times: ')
    for i in range(10):
        print(sampler(probs, dummy_prev_tokens))
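
Below is a minimal sketch of how the sampler could drive an autoregressive decoding loop. The names model, bos_id, eos_id, and max_len are hypothetical placeholders, not part of this gist; model is assumed to map a 1-D LongTensor of token ids to a 1-D tensor of next-token logits.

# Usage sketch (assumptions: `model` returns next-token logits of shape (vocab_size,)
# for a 1-D LongTensor of token ids; `bos_id`/`eos_id`/`max_len` are placeholders).
def generate(model, sampler: Sampler, bos_id: int, eos_id: int, max_len: int = 128):
    tokens = [bos_id]
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(torch.LongTensor(tokens))
        probs = torch.softmax(logits, dim=-1)               # Sampler expects probabilities
        next_id = sampler(probs, torch.LongTensor(tokens))  # penalize already-emitted tokens
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens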