@vincentclaes
Last active September 27, 2022 10:59
A snippet that predicts an emoji for a piece of text using a pre-trained CLIP model.
# install these dependencies
# pip install torch transformers pillow
# import the dependencies
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
# emoji images: https://public-assets-vincent-claes.s3.eu-west-1.amazonaws.com/emoji-precitor/emojis.zip
path_to_emoji_folder = "<path to the folder with images of emojis>"
# read the emoji images (0.png up to 30.png)
emojis_as_images = [Image.open(f"{path_to_emoji_folder}/{i}.png") for i in range(31)]
# the text you want to match with an emoji
text = "provide-some-text"
# load model and processor
checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(checkpoint)
processor = CLIPProcessor.from_pretrained(checkpoint)
# process inputs and make a prediction (no_grad since we only run inference)
inputs = processor(text=text, images=emojis_as_images, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = model(**inputs)
# we want the probability for each emoji per sentence.
logits_per_text = outputs.logits_per_text
# we take the softmax to get the label probabilities.
prob = logits_per_text.softmax(dim=1)
# the label is the index of the best-matching emoji in the list
# of images we fed to the processor.
label = torch.argmax(prob).item()
# print the label of the emoji that best describes the text.
print(label) # 1
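# (optional extension, not in the original snippet) a small sketch that maps
# the predicted label back to its emoji image and inspects the runner-up
# candidates; it assumes the emoji images loaded above are still in memory.
# PIL's .show() opens the image in your default viewer.
predicted_emoji = emojis_as_images[label]
predicted_emoji.show()
# torch.topk returns the k highest probabilities and their indices,
# which is handy when the single best match is ambiguous.
top_probs, top_labels = torch.topk(prob, k=3, dim=1)
print(top_labels.tolist())  # indices of the 3 best-matching emojis
print(top_probs.tolist())   # their probabilities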