from transformers import AutoProcessor, BlipForConditionalGeneration
from PIL import Image
import requests
import torch
# use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# load model
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
model = model.to(device)
# preprocess the image and tokenize the caption we want to score
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
caption = "two cats"
inputs = processor(image, caption, return_tensors="pt")
inputs = inputs.to(device)
# inference (no gradients needed, we only want the logits)
with torch.no_grad():
    out = model(**inputs)
# sanity check
print(out["decoder_logits"].shape)
# torch.Size([1, 4, 30524]) = [batch_size, num_tokens, vocab_size], i.e. a distribution over the next token at each position
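# the logits at position i score the token at position i + 1, so token i is
# scored by the logits at position i - 1 (hence the i - 1 index in the loop below)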
print(inputs["input_ids"])
# tensor([[ 101, 2048, 8870, 102]], device='cuda:0')
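# decoded, this is [CLS] two cats [SEP]; token 0 ([CLS]) is skipped when scoring below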
# get logits and tokens for calculating probs
logits = out["decoder_logits"]
tokens = inputs["input_ids"]
batches, num_tokens = logits.shape[0], logits.shape[1]
# apply log-softmax over the vocab dimension to turn the logits into log-probabilities
softmax = torch.nn.LogSoftmax(dim=2)
logprobs = softmax(logits)
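# optional sanity check (not in the original gist): after exponentiating, the
# probabilities at each position should sum to ~1 over the vocabulary
# print(logprobs.exp().sum(dim=-1))  # expect values close to 1.0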
# calculate probs
for b in range(batches):
    last_prob = 0.0
    prob = 0.0
    # skip position 0 ([CLS]); token i is scored by the logits at position i - 1
    for i in range(1, num_tokens):
        token = tokens[b][i]
        token_prob = logprobs[b][i - 1][token]
        print(f"{processor.decode(token):<20} {token_prob.exp():.8f}")
        prob += token_prob
        last_prob = prob - token_prob  # running total excluding the current (last) token
    print("-" * 31)
    print(f"{'FINAL PROB (NO SEP)':<20} {last_prob.exp():.8f}")
    print(f"{'FINAL PROB':<20} {prob.exp():.8f}")
# two 0.58799326
# cats 0.72933346
# [SEP] 0.00012775
# -------------------------------
# FINAL PROB (NO SEP) 0.42884299
# FINAL PROB 0.00005478
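# A minimal sketch (not part of the original gist) that wraps the scoring logic
# above into a reusable helper; it assumes the `processor`, `model`, and `device`
# objects defined earlier, and returns the total log-probability of `caption`
# given `image` (including the final [SEP] token, like FINAL PROB above).
def caption_log_prob(image, caption):
    inputs = processor(image, caption, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**inputs)
    logprobs = torch.nn.functional.log_softmax(out["decoder_logits"], dim=-1)
    tokens = inputs["input_ids"][0]
    # sum the log-probabilities of tokens 1..n-1, each scored by the logits one step earlier
    return sum(logprobs[0, i - 1, tokens[i]] for i in range(1, tokens.shape[0]))

# example usage: compare candidate captions for the same image
# print(caption_log_prob(image, "two cats").exp())
# print(caption_log_prob(image, "two dogs").exp())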