# Scraped from a GitHub Gist page (navigation/UI text removed).
# Author:  @thistleknot
# Created: January 14, 2024 03:41
# Gist:    thistleknot/023f89b69dab8816b7fc01f2e7010e1c
# Title:   GPT Classify
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn.functional as F
# Zero-shot yes/no classification with GPT-2: ask the model a question,
# then compare its next-token scores for "Y" versus "N".

# Load the pretrained GPT-2 language model and its tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# The question to classify and the instruction prompt wrapped around it.
question = "Is a bird a mammal?"
prompt = f"""
System:
Your role is to answer with a single character, Y for Yes, N for No.
{question}
Y or N?
Response:
"""

# Encode the prompt to a tensor of token ids (a batch of one sequence).
encoded_input = tokenizer.encode(prompt, return_tensors='pt')

# Token ids of the two candidate answers. The prompt ends with a newline,
# so the answer is scored as a bare "Y"/"N" with no leading space. Each
# label must be exactly one token, otherwise comparing a single next-token
# logit would silently score only a fragment of the label.
tokens_y = tokenizer.encode('Y')
tokens_n = tokenizer.encode('N')
if len(tokens_y) != 1 or len(tokens_n) != 1:
    raise ValueError("'Y' and 'N' must each encode to exactly one token")
id_y = tokens_y[0]
id_n = tokens_n[0]

# Forward pass for inference only, so skip gradient tracking.
with torch.no_grad():
    outputs = model(encoded_input)
predictions = outputs.logits

# Logits for 'Y' and 'N' at the last prompt position, i.e. the model's
# scores for the token that would come next.
logit_y = predictions[:, -1, id_y]
logit_n = predictions[:, -1, id_n]

# Softmax over just these two candidates renormalises them into
# P(Y) + P(N) = 1. torch.cat keeps the data as tensors instead of
# round-tripping each logit through a Python scalar via torch.tensor().
probs = F.softmax(torch.cat([logit_y, logit_n]), dim=0)
prob_y = probs[0]
prob_n = probs[1]
print('prob y:', prob_y)
print('prob n:', prob_n)

# Pick the more likely answer; an exact tie falls through to 'No'.
answer = 'Yes' if prob_y > prob_n else 'No'
print(f"{question} {answer}")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment