Skip to content

Instantly share code, notes, and snippets.

@brandon-lockaby
Last active May 7, 2024 22:33
Show Gist options
  • Save brandon-lockaby/0e357aecfe51bbd53a4c41457c29d484 to your computer and use it in GitHub Desktop.
Save brandon-lockaby/0e357aecfe51bbd53a4c41457c29d484 to your computer and use it in GitHub Desktop.
MultiClassifier
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Device that the model weights and every input tensor are placed on.
DEV = "cuda"

# Checkpoint choice: a base model would be preferable here; this one is
# instruct-tuned, but it fits on the author's GPU.
model_path = "microsoft/Phi-3-mini-128k-instruct"

# The tokenizer and the model are loaded independently from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map=DEV,
)
class MultiClassifier():
    """Few-shot, multi-label yes/no classifier built on a causal language model.

    The few-shot ``prompt`` is run through the model once, at construction
    time, and the resulting key/value cache is kept. ``classify()`` then
    appends the text to classify followed by one ``"<class name>:"`` line per
    class and decides each class by comparing the next-token logits of
    ``" yes"`` and ``" no"``. The chosen answer is fed back into the context
    before the next class is asked, so later answers are conditioned on the
    earlier ones.

    Instance attributes: ``dev``, ``model``, ``tokenizer``, ``prompt``,
    ``class_names``, ``prompt_ids``, ``kv_cache``, ``yes``, ``no``,
    ``yes_id`` and ``no_id``.
    """

    def __init__(self, dev, model, tokenizer, prompt, class_names):
        """Tokenize ``prompt`` and pre-compute its key/value cache.

        Args:
            dev: Device (e.g. ``"cuda"``) on which input tensors are created.
            model: Causal LM called as ``model(ids, past_key_values=...,
                attention_mask=..., return_dict=True)``; its output must
                expose ``logits`` and ``past_key_values``.
            tokenizer: Tokenizer providing ``encode(text, add_special_tokens=...,
                return_tensors=...)``.
            prompt: Few-shot prompt; the text passed to ``classify()`` is
                appended directly after it.
            class_names: Class labels, queried in this order.
        """
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored ``self`` on the instance (a reference cycle).
        self.dev = dev
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.class_names = class_names

        # Tokenize the given prompt once; it is shared by every classify() call.
        self.prompt_ids = tokenizer.encode(self.prompt, return_tensors="pt").to(self.dev)
        # Pre-compute the prompt's KV cache for reuse. No gradients are needed,
        # and skipping autograd keeps the cached tensors free of a graph.
        with torch.no_grad():
            self.kv_cache = self.model(self.prompt_ids, return_dict=True).past_key_values

        # Answer strings and the id of the last token of each one.
        self.yes = " yes"
        self.no = " no"
        self.yes_id = self._last_token_id(self.yes)
        self.no_id = self._last_token_id(self.no)

    def _last_token_id(self, text):
        """Return the id of the last token of ``text`` as a ``(1, 1)`` tensor on ``self.dev``."""
        ids = self.tokenizer.encode(text, add_special_tokens=False)
        return torch.tensor(ids[-1]).to(self.dev).unsqueeze(0).unsqueeze(0)

    def classify(self, held_out_example, return_probs=False):
        """Decide, for every class name, whether ``held_out_example`` belongs to it.

        Args:
            held_out_example: Text to classify; appended after the prompt.
            return_probs: When true, also return the per-class probabilities.

        Returns:
            The list of class names answered "yes" (ties count as "yes"). With
            ``return_probs`` a tuple ``(class_list, probs)`` where ``probs``
            maps each class name to ``{"yes": float, "no": float}``, the
            softmax over just the two answer logits.
        """
        # Local import: only needed when the cache is a mutable object.
        import copy

        output_class_list = []
        output_probs = {}

        kv_cache = self.kv_cache
        if not isinstance(kv_cache, tuple):
            # A mutable cache object would be extended in place by the model
            # and leak this call's context into later calls; work on a copy.
            # Tuple caches are immutable containers, so they are shared as is.
            kv_cache = copy.deepcopy(kv_cache)

        # Vocabulary positions of the two answers, as a 1-D index tensor.
        answer_ids = torch.cat((self.yes_id.flatten(), self.no_id.flatten()))
        # Number of tokens the cache covers so far (prompt + everything fed below).
        context_len = self.prompt_ids.shape[-1]

        new_text = held_out_example
        with torch.no_grad():
            # Iterate through all the class names.
            for class_name in self.class_names:
                # Feed the pending text plus the class-name marker; the model
                # then predicts the token that follows "<class name>:".
                prompt_ids = self.tokenizer.encode(
                    f"{new_text}\n{class_name}:",
                    add_special_tokens=False,
                    return_tensors="pt",
                ).to(self.dev)
                context_len += prompt_ids.shape[-1]
                # One mask entry per cached token plus per new token, batch of 1.
                attention_mask = torch.ones(
                    (1, context_len), dtype=torch.long, device=self.dev
                )
                outputs = self.model(
                    prompt_ids,
                    past_key_values=kv_cache,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
                kv_cache = outputs.past_key_values

                # Keep only the two logits of interest at the last position ...
                logits = outputs.logits[-1, -1, answer_ids]
                # ... and convert them to probabilities.
                probs = torch.nn.functional.softmax(logits, dim=-1)
                yes_prob = probs[0].item()
                no_prob = probs[1].item()

                # Record the decision and feed the chosen answer back so the
                # next class is asked with this answer in the context.
                if yes_prob >= no_prob:
                    output_class_list.append(class_name)
                    new_text = self.yes
                else:
                    new_text = self.no
                if return_probs:
                    output_probs[class_name] = {"yes": yes_prob, "no": no_prob}

        return (output_class_list, output_probs) if return_probs else output_class_list
# Few-shot prompt: every example is a "Text:" line followed by one
# "<class name>: yes|no" line per class. The prompt ends with an open
# "Text: "; classify() feeds the text to classify right after it.
prompt = """Text: I ate an apple and then a few oranges.
Apples: yes
Oranges: yes
Text: Do you sell chocolate oranges?
Apples: no
Oranges: yes
Text: I want something red to eat.
Apples: yes
Oranges: no
Text: Orange you glad I didn't say apple?
Apples: yes
Oranges: yes
Text: I hate oranges and I hate apples!
Apples: yes
Oranges: yes
Text: My car is orange
Apples: no
Oranges: no
Text: Red!
Apples: no
Oranges: no
Text: These can sometimes be red.
Apples: no
Oranges: no
Text: orange
Apples: no
Oranges: yes
Text: What are you eating?
Apples: no
Oranges: no
Text: """
# Class names queried by classify(), in this order; they are spelled exactly
# like the labels used in the prompt's example lines.
class_names = ["Apples", "Oranges"]
# Constructing the classifier runs the model over the prompt once.
classifier = MultiClassifier(DEV, model, tokenizer, prompt, class_names)
def test(text):
    """Classify ``text`` with the module-level classifier and print the outcome.

    The printed record is a blank line, the input text, then a tab-indented
    ``(class_list, probabilities)`` tuple.
    """
    outcome = classifier.classify(text, return_probs=True)
    print(f"\n{text}\n\t{outcome}")
# Demo inputs, classified and printed one after another in this order.
for sample in (
    "You can't squeeze ketchup from a banana.",
    "Do you like apple pie?",
    "Too bad. I baked an orange pie.",
    "DO NOT give me apple pie.",
    "red",
    "These can sometimes be red.",
    "orangey",
    "No apples and no oranges",
    "What are you eating?",
    "Orples",
    "What about a-p-p-l-e",
):
    test(sample)
# Output
#You can't squeeze ketchup from a banana.
#([], {'Apples': {'yes': 0.0011695101857185364, 'no': 0.9988304972648621}, 'Oranges': {'yes': 0.005220125894993544, 'no': 0.9947799444198608}})
#Do you like apple pie?
#(['Apples'], {'Apples': {'yes': 0.9997387528419495, 'no': 0.00026119028916582465}, 'Oranges': {'yes': 6.144174221844878e-06, 'no': 0.9999938011169434}})
#Too bad. I baked an orange pie.
#(['Oranges'], {'Apples': {'yes': 0.0010322310263291001, 'no': 0.9989677667617798}, 'Oranges': {'yes': 0.9999938011169434, 'no': 6.144174221844878e-06}})
#DO NOT give me apple pie.
#(['Apples'], {'Apples': {'yes': 0.9740425944328308, 'no': 0.02595735713839531}, 'Oranges': {'yes': 2.6729447100137804e-08, 'no': 1.0}})
#red
#([], {'Apples': {'yes': 0.02595735713839531, 'no': 0.9740425944328308}, 'Oranges': {'yes': 0.2018132209777832, 'no': 0.7981867790222168}})
#These can sometimes be red.
#([], {'Apples': {'yes': 0.0019267346942797303, 'no': 0.9980732202529907}, 'Oranges': {'yes': 0.0534033328294754, 'no': 0.9465966820716858}})
#orangey
#(['Oranges'], {'Apples': {'yes': 0.00026119028916582465, 'no': 0.9997387528419495}, 'Oranges': {'yes': 0.9890130758285522, 'no': 0.01098694372922182}})
#No apples and no oranges
#([], {'Apples': {'yes': 0.0008040859247557819, 'no': 0.9991958737373352}, 'Oranges': {'yes': 0.0024726232513785362, 'no': 0.9975274205207825}})
#What are you eating?
#([], {'Apples': {'yes': 0.00048785717808641493, 'no': 0.9995121955871582}, 'Oranges': {'yes': 0.0011695101857185364, 'no': 0.9988304972648621}})
#Orples
#([], {'Apples': {'yes': 3.120191104244441e-05, 'no': 0.9999687671661377}, 'Oranges': {'yes': 0.007577240467071533, 'no': 0.9924227595329285}})
#What about a-p-p-l-e
#(['Apples'], {'Apples': {'yes': 0.9046505093574524, 'no': 0.09534946084022522}, 'Oranges': {'yes': 0.0035936026833951473, 'no': 0.9964063763618469}})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment