Skip to content

Instantly share code, notes, and snippets.

@thistleknot
Last active January 14, 2024 04:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thistleknot/83f17cf8fe9b917c6edc0923fa04d71a to your computer and use it in GitHub Desktop.
GPT-Neo Classify
from transformers import GPT2Tokenizer, GPTNeoForCausalLM
import torch
import torch.nn.functional as F
# Zero-shot Y / N / ? classification with GPT-Neo: build an instruction prompt,
# run one forward pass, and compare the next-token logits of the three answer
# labels at the final prompt position.

# Load the GPT-Neo 1.3B model and tokenizer.
MODEL_NAME = "EleutherAI/gpt-neo-1.3B"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPTNeoForCausalLM.from_pretrained(MODEL_NAME)
model.eval()  # explicit inference mode (disables dropout)

# Your question and prompt
question = "Is a bird a mammal?"
prompt = f"""
System:
Your role is to answer with a single character, Y for Yes, N for No, and ? for I don't know.
{question}
Y, N, or ?
Response:
"""

# Candidate answer tokens, in the same order as the keys of `answers`.
# No leading space: the prompt ends with a newline, so the answer starts a line
# and GPT-2 BPE encodes it without the "Ġ" space prefix.
labels = ("Y", "N", "?")
answers = {0: 'Yes', 1: 'No', 2: 'I don\'t know'}


def _single_token_id(text):
    """Return the vocabulary id of *text*, which must encode to exactly one token.

    Raises:
        ValueError: if *text* encodes to zero or several tokens. Scoring only the
            first sub-token would silently compare the wrong logits.
    """
    ids = tokenizer.encode(text)
    if len(ids) != 1:
        raise ValueError(f"{text!r} encodes to {len(ids)} tokens; expected exactly 1")
    return ids[0]


label_ids = torch.tensor([_single_token_id(label) for label in labels])

# Encode the prompt to a tensor
encoded_input = tokenizer.encode(prompt, return_tensors='pt')

# Get model predictions (logits); no gradients are needed for inference.
with torch.no_grad():
    outputs = model(encoded_input)
predictions = outputs.logits  # rank 3: (batch, position, vocab)

# Logits of the three label tokens at the last prompt position, for the single
# prompt in the batch (index 0).
label_logits = predictions[0, -1, label_ids]

# Softmax over only these three candidates. The result is a distribution
# restricted to {Y, N, ?}, not the model's full-vocabulary probabilities.
probs = F.softmax(label_logits, dim=0)

# Map the most probable label to its human-readable answer.
max_index = int(torch.argmax(probs))
selected_answer = answers[max_index]

print(f'Probability Yes: {probs[0].item()}')
print(f'Probability No: {probs[1].item()}')
print(f'Probability I don\'t know: {probs[2].item()}')
# Use the actual question so the printout stays correct if `question` changes.
print(f'{question} {selected_answer}')
@thistleknot
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment