Skip to content

Instantly share code, notes, and snippets.

@sravantit25
Created June 20, 2023 11:36
Show Gist options
  • Save sravantit25/ba50f13af1f595a2ba098c8808f9348a to your computer and use it in GitHub Desktop.
Save sravantit25/ba50f13af1f595a2ba098c8808f9348a to your computer and use it in GitHub Desktop.
"""
This script is used to fetch the anniversary number from the Qxf2's work anniversary image.
For this it uses Pix2Struct, a pre-trained image-to-text model for extracting text from an image
and performing visual question answering.
"""
import logging
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration as psg
from transformers import Pix2StructProcessor as psp
# Pick the torch device string used for inference: the GPU when torch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Append INFO-and-above records to anniversary_test.log (relative to the
# current working directory). Note that basicConfig is a no-op if the root
# logger already has handlers configured.
logging.basicConfig(
    filename="anniversary_test.log",
    filemode="a",
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
def instantiate_model_and_processor(model_name="google/pix2struct-docvqa-large"):
    """
    Instantiates the Pix2Struct model and its associated processor.

    The model is moved to DEVICE (CUDA when available, else CPU) so that
    inference actually runs on the selected device. It also then matches the
    device that the VQA inputs are sent to.

    Args:
        model_name (str): Checkpoint id passed to ``from_pretrained`` for both
            the model and the processor.
            Defaults to "google/pix2struct-docvqa-large".
    Returns:
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct model
        processor (Pix2StructProcessor): The processor for handling inputs
    Raises:
        Exception: Whatever ``from_pretrained`` (or the device move) raises,
            re-raised after being logged with its traceback.
    """
    try:
        # Without .to(DEVICE) the model stays on the CPU. Inputs placed on CUDA
        # would then cause a device-mismatch error during generation.
        model = psg.from_pretrained(model_name).to(DEVICE)
        processor = psp.from_pretrained(model_name)
    except Exception as error:
        logging.exception("Failed to instantiate model and processor: %s", error)
        raise
    return model, processor
def perform_vqa(image_path, question, model, processor, max_new_tokens=256):
    """
    Performs visual question answering on the given image and question.

    Args:
        image_path (str): The path of the image
        question (str): The question to ask
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct model
        processor (Pix2StructProcessor): The processor for handling inputs
        max_new_tokens (int): Maximum number of tokens to generate for the
            answer. Defaults to 256.
    Returns:
        answer (str): The generated answer to the question
    Raises:
        FileNotFoundError: If no file exists at image_path (logged, then re-raised).
        Exception: Any other failure while loading, preprocessing or
            generating, logged with its traceback and re-raised.
    """
    try:
        # The context manager closes the image's file handle. Preprocessing stays
        # inside it because PIL reads pixel data lazily.
        with Image.open(image_path) as image:
            inputs = processor(images=[image], text=[question], return_tensors="pt")
        # Send the inputs to wherever the model actually lives, so the two
        # can never end up on different devices.
        inputs = inputs.to(model.device)
        predictions = model.generate(**inputs, max_new_tokens=max_new_tokens)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
    except FileNotFoundError:
        logging.error("Image file not found: %s", image_path)
        raise
    except Exception as error:
        logging.exception("Failed to perform visual question answering: %s", error)
        raise
    return answer
def get_anniversary_year(
    image_path,
    model,
    processor,
    question="What is the number or numeric value that comes after the word Happy?",
):
    """
    Fetches the number (of years) from the work anniversary image.

    Args:
        image_path (str): The path of the image
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct model
        processor (Pix2StructProcessor): The processor for handling inputs
        question (str): The question posed to the model. The default asks for
            the number that follows the word "Happy" in the image.
    Returns:
        answer (str): The model's raw decoded answer. It is meant to be the
            number of years, not a calendar year, and it is not converted to int.
    Raises:
        Exception: Whatever perform_vqa raises, re-raised after logging.
    """
    try:
        answer = perform_vqa(image_path, question, model, processor)
    except Exception as error:
        # perform_vqa already logs the low-level details. This adds context
        # about which higher-level operation failed.
        logging.error("Failed to get anniversary year: %s", error)
        raise
    logging.info("Anniversary year for %s is %s", image_path, answer)
    return answer
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment