Skip to content

Instantly share code, notes, and snippets.

@alexpaden
Created September 10, 2023 17:21
Show Gist options
  • Save alexpaden/517aaa2195b9bbc15a979d09202bb1c0 to your computer and use it in GitHub Desktop.
Save alexpaden/517aaa2195b9bbc15a979d09202bb1c0 to your computer and use it in GitHub Desktop.
Binary Meme Classifier
import functools
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor
def download_image(url, timeout=10):
    """Download an image over HTTP(S) and open it with Pillow.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        A ``PIL.Image.Image`` on success. Returns ``None`` in any of these cases:
        the request fails, the server does not answer 200, the response is not
        declared as an image, or Pillow cannot decode the bytes.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection/DNS errors, timeouts and malformed URLs are reported as a
        # failed download, matching the function's None-on-failure contract.
        print(f"Failed to download image from {url}: {exc}")
        return None

    content_type = response.headers.get('Content-Type')
    print(f"Status Code: {response.status_code}")
    print(f"Content-Type: {content_type}")

    # The header may be missing entirely (None), and media types are
    # case-insensitive, so normalise before checking the "image/" prefix.
    is_image = bool(content_type) and content_type.lower().startswith('image/')
    if response.status_code != 200 or not is_image:
        print(f"Failed to download image from {url}")
        return None

    try:
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding now so corrupt or truncated data
        # is reported here instead of failing later inside the caller.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError covers PIL.UnidentifiedImageError and truncated-file errors.
        print(f"Failed to decode image from {url}: {exc}")
        return None
    return image
@functools.lru_cache(maxsize=2)
def _load_processor_and_model(model_name):
    """Load and memoize the Hugging Face processor/model pair for *model_name*."""
    # Loading weights is the most expensive step, so do it once per model name
    # rather than on every classification. The bound keeps memory in check if
    # callers cycle through several model names.
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    model.eval()  # inference only; explicit even though from_pretrained defaults to eval mode
    return processor, model


def classify_meme(image_url, model_name="Hrishikesh332/autotrain-meme-classification-42897109437"):
    """Classify the image at *image_url* as a meme or not.

    Args:
        image_url: URL of the image to classify. It is fetched with
            ``download_image``.
        model_name: Hugging Face model id of the image-classification
            checkpoint to use.

    Returns:
        ``"Meme"`` or ``"Not Meme"``. Returns the string
        ``"Failed to download image."`` if the image could not be retrieved.
    """
    image = download_image(image_url)
    if image is None:
        return "Failed to download image."

    # The processor expects 3-channel input; palette/RGBA/greyscale images
    # are converted first.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_np = np.array(image)

    processor, model = _load_processor_and_model(model_name)
    inputs = processor(images=image_np, return_tensors="pt")

    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()

    # NOTE(review): assumes class index 0 means "meme" for this checkpoint;
    # confirm against model.config.id2label before changing model_name.
    return "Meme" if predicted_class_idx == 0 else "Not Meme"
if __name__ == "__main__":
    # Strip surrounding whitespace that often comes along with a pasted URL.
    image_url = input("Enter the image URL: ").strip()
    if not image_url:
        raise SystemExit("No image URL entered.")
    label = classify_meme(image_url)
    print(f"The image is classified as: {label}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment