-
-
Save sravantit25/ba50f13af1f595a2ba098c8808f9348a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
This script fetches the anniversary number from a Qxf2 work anniversary image.
It uses Pix2Struct, a pre-trained image-to-text model that can extract text from an
image and perform visual question answering.
"""
import logging | |
import torch | |
from PIL import Image | |
from transformers import Pix2StructForConditionalGeneration as psg | |
from transformers import Pix2StructProcessor as psp | |
# Inference device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Append timestamped records of level INFO and above to anniversary_test.log.
logging.basicConfig(
    filename="anniversary_test.log",
    filemode="a",
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
def instantiate_model_and_processor(model_name="google/pix2struct-docvqa-large"):
    """
    Instantiates the Pix2Struct model and its associated processor

    The model is moved to DEVICE so that it sits on the same device as the
    inputs fed to it; from_pretrained alone leaves the weights on the CPU,
    which breaks generation when the inputs are placed on a CUDA device.

    Args:
        model_name (str): Hugging Face Hub id or local path of the checkpoint
            to load. Defaults to "google/pix2struct-docvqa-large".
    Returns:
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct
            model, placed on DEVICE
        processor (Pix2StructProcessor): The processor for handling inputs
    Raises:
        Exception: Any error raised while loading the checkpoint is logged
            (with traceback) and re-raised unchanged.
    """
    try:
        model = psg.from_pretrained(model_name).to(DEVICE)
        processor = psp.from_pretrained(model_name)
    except Exception as error:
        # logging.exception keeps the traceback in the log file.
        logging.exception("Failed to instantiate model and processor: %s", error)
        raise
    return model, processor
def perform_vqa(image_path, question, model, processor, max_new_tokens=256):
    """
    Performs visual question answering on the given image and question

    Args:
        image_path (str): The path of the image
        question (str): The question to ask
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct model
        processor (Pix2StructProcessor): The processor for handling inputs
        max_new_tokens (int): Maximum number of tokens to generate for the
            answer. Defaults to 256.
    Returns:
        answer (str): The generated answer to the question
    Raises:
        FileNotFoundError: If image_path does not exist (logged, then re-raised).
        Exception: Any other failure while preprocessing or generating is
            logged with its traceback and re-raised.
    """
    try:
        # Open the image in a `with` block so the file handle is released
        # once the processor has read the pixels.
        with Image.open(image_path) as image:
            # Send the inputs to wherever the model's weights live; a fixed
            # device would clash with a model loaded on a different one.
            inputs = processor(
                images=[image], text=[question], return_tensors="pt"
            ).to(model.device)
        predictions = model.generate(**inputs, max_new_tokens=max_new_tokens)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        return answer
    except FileNotFoundError:
        logging.error("Image file not found: %s", image_path)
        raise
    except Exception as error:
        logging.exception("Failed to perform visual question answering: %s", error)
        raise
def get_anniversary_year(
    image_path,
    model,
    processor,
    question="What is the number or numeric value that comes after the word Happy?",
):
    """
    Fetches the number (of years) from the work anniversary image

    Args:
        image_path (str): The path of the image
        model (Pix2StructForConditionalGeneration): The pre-trained Pix2Struct model
        processor (Pix2StructProcessor): The processor for handling inputs
        question (str): The prompt sent to the VQA model. The default asks for
            the number that follows the word "Happy" in the image.
    Returns:
        answer (str): The anniversary year as decoded by the model. This is the
            raw model text; it is not validated or converted to a number.
    Raises:
        Exception: Any error from perform_vqa is logged and re-raised.
    """
    try:
        answer = perform_vqa(image_path, question, model, processor)
    except Exception as error:
        logging.error("Failed to get anniversary year: %s", error)
        raise
    logging.info("Anniversary year for %s is %s", image_path, answer)
    return answer
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment