@thesephist · Created September 4, 2023
Code (most of it) for my GPT2 perplexities visualizer UI: https://twitter.com/thesephist/status/1617747154231259137
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# The gist is "most of" the code; `device` is defined elsewhere in the
# original. This is the standard idiom consistent with its usage below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Use the largest GPT-2 variant only when a GPU is available.
ppl_model_name = 'gpt2-xl' if device == 'cuda' else 'gpt2'
ppl_tokenizer = GPT2Tokenizer.from_pretrained(ppl_model_name)
load_opts = {
    'device_map': 'auto',
    'torch_dtype': torch.float16,
} if torch.cuda.is_available() else {}
ppl_model = GPT2LMHeadModel.from_pretrained(ppl_model_name, **load_opts).to(device)
def perplexities(text: str, stride: int = 128):
    tokenizer, model = ppl_tokenizer, ppl_model

    def tokenize(text: str) -> torch.LongTensor:
        # Prepend BOS so the first real token also gets a prediction.
        return tokenizer(tokenizer.bos_token + text, return_tensors='pt').input_ids[0].to(device)

    def token_list(tokens: torch.LongTensor) -> List[str]:
        # Decode each token id to its string form, one at a time.
        return tokenizer.batch_decode(tokens.unsqueeze(1))

    max_length = model.config.n_positions
    input_ids = tokenize(text).to(device).unsqueeze(0)
    seq_len = input_ids.size(1)

    top_k = 10
    tokens = []
    # Slide a window of at most max_length tokens across the sequence in
    # stride-sized steps, so texts longer than the model's context still fit.
    for begin_loc in range(0, max(1, seq_len - max_length + stride), stride):
        end_loc = min(begin_loc + max_length, seq_len - 1)
        span_input_ids = input_ids[:, begin_loc:end_loc]
        # Targets are the inputs shifted left by one position.
        target_ids = input_ids[:, begin_loc+1:end_loc+1]
        with torch.no_grad():
            # Only the logits are used below; the labels-derived loss is discarded.
            outputs = model(span_input_ids, labels=target_ids)
            logits = outputs.logits
            log_probs = F.log_softmax(logits, dim=-1)
            probs = F.softmax(logits, dim=-1)
            # Probability mass the model put on each actual next token.
            target_log_probs = log_probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
            target_probs = probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
            # The model's top-k most likely alternatives at each position.
            greedy_log_probs, greedy_tokens = log_probs.topk(top_k, dim=2)
            greedy_probs = torch.exp(greedy_log_probs)
        # For windows after the first, skip the leading tokens that overlap
        # the previous window so each token is emitted exactly once.
        for tok, predicted_toks, log_prob, prob in list(zip(
            token_list(target_ids[0]),
            [
                zip(topk_log_probs, topk_probs, token_list(topk_tokens))
                for topk_log_probs, topk_probs, topk_tokens
                in zip(
                    greedy_log_probs[0].tolist(),
                    greedy_probs[0].tolist(),
                    greedy_tokens[0],
                )
            ],
            target_log_probs[0].tolist(),
            target_probs[0].tolist(),
        ))[max_length - stride if begin_loc > 0 else 0:]:
            tokens.append({
                'token': tok,
                'predicted_tokens': [{
                    'token': alt_tok,
                    'log_prob': alt_log_prob,
                    'prob': alt_prob,
                } for alt_log_prob, alt_prob, alt_tok in predicted_toks],
                'log_prob': log_prob,
                'prob': prob,
            })
    return tokens
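
A quick usage sketch (not part of the original gist; the sample sentence and variable names are illustrative): call perplexities() on a string, print each token's probability next to the model's top-1 alternative, and recover an overall perplexity as the exponential of the mean negative log probability.

import math

scored = perplexities('The quick brown fox jumps over the lazy dog.')
for entry in scored:
    top = entry['predicted_tokens'][0]  # topk is sorted, so [0] is the top-1 guess
    print(f"{entry['token']!r}\tp={entry['prob']:.4f}\ttop-1={top['token']!r} (p={top['prob']:.4f})")
# Perplexity is exp of the mean negative log probability over all tokens.
ppl = math.exp(-sum(e['log_prob'] for e in scored) / len(scored))
print(f'perplexity: {ppl:.2f}')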