Skip to content

Instantly share code, notes, and snippets.

@saahiluppal
Created April 2, 2021 06:55
Show Gist options
  • Save saahiluppal/4e04d05b3499660797fb6a0d452d46a1 to your computer and use it in GitHub Desktop.
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
import os
from datasets import coco, utils
from configuration import Config
# Command-line interface: a required image directory and an optional model version.
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, required=True, help='path to image')
parser.add_argument('--v', type=str, default='v3', help='version')
args = parser.parse_args()

# Unpack into module-level names used by the rest of the script.
directory_path = args.path
version = args.v
# Every released checkpoint lives in the same hub repo under its version tag,
# so one membership check plus a single load call replaces the if/elif chain.
if version not in ('v1', 'v2', 'v3'):
    raise NotImplementedError('Version not implemented')
model = torch.hub.load('saahiluppal/catr', version, pretrained=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = Config()

# Use the tokenizer's public special-token-id properties instead of the
# private `_cls_token` / `_sep_token` attributes (internal API, may change
# between transformers releases). Values are unchanged for bert-base-uncased
# ([CLS] -> 101, [SEP] -> 102).
start_token = tokenizer.cls_token_id
end_token = tokenizer.sep_token_id
# Bail out early when the input directory does not exist.
if not os.path.exists(directory_path):
    print("Directory Not Found...")
    # `exit()` is the interactive-shell helper and would terminate with
    # status 0 even on this error path; signal failure explicitly instead.
    raise SystemExit(1)

images_in_dir = os.listdir(directory_path)
print("Starting...")
def create_caption_and_mask(start_token, max_length):
    """Build the initial decoder input for one image.

    Args:
        start_token: token id placed at position 0 (the tokenizer's [CLS] id).
        max_length: number of decoding slots (config.max_position_embeddings).

    Returns:
        (caption, mask): a (1, max_length) LongTensor of token ids, all zero
        except the start token, and a (1, max_length) BoolTensor padding mask
        where True marks a not-yet-generated position (only slot 0 is False).
    """
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)
    caption_template[:, 0] = start_token
    mask_template[:, 0] = False  # the start token is visible to the decoder
    return caption_template, mask_template


@torch.no_grad()
def evaluate():
    """Greedily decode a caption for the current image.

    Reads the module-level `model`, `image`, `caption`, `cap_mask`, `config`
    and `end_token`; fills `caption` in place one token per step and stops
    early once the end-of-sequence token is produced.
    """
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)
        predictions = predictions[:, i, :]
        # `dim` is the torch-native kwarg; `axis` is only a NumPy-compat alias.
        predicted_id = torch.argmax(predictions, dim=-1)
        # Stop at the tokenizer's actual [SEP] id (end_token, computed above)
        # rather than the hard-coded 102, so a different tokenizer still works.
        if predicted_id[0] == end_token:
            return caption
        caption[:, i + 1] = predicted_id[0]
        cap_mask[:, i + 1] = False  # newly generated slot becomes visible
    return caption


for image_path in images_in_dir:
    # Load the image, apply the COCO validation transform, add a batch dim.
    image = Image.open(os.path.join(directory_path, image_path))
    image = coco.val_transform(image)
    image = image.unsqueeze(0)

    # Fresh caption buffer and mask for every image.
    caption, cap_mask = create_caption_and_mask(
        start_token, config.max_position_embeddings)

    output = evaluate()
    result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
    print(image_path, ":", result.capitalize())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment