Created
April 2, 2021 06:55
-
-
Save saahiluppal/4e04d05b3499660797fb6a0d452d46a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
import sys

import torch
from PIL import Image
from transformers import BertTokenizer

from configuration import Config
from datasets import coco, utils
# Command-line interface: a directory of images to caption and the
# checkpoint version to load from the torch hub.
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image directory', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
args = parser.parse_args()

directory_path = args.path
version = args.v

# Every released checkpoint lives in the same hub repo and the version
# string doubles as the hub entry-point name, so one call covers all
# cases instead of a duplicated if/elif chain per version.
if version not in ('v1', 'v2', 'v3'):
    raise NotImplementedError('Version not implemented')
model = torch.hub.load('saahiluppal/catr', version, pretrained=True)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = Config()

# [CLS] starts a caption, [SEP] terminates it (ids taken from the
# tokenizer rather than hard-coded).
start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)

if not os.path.exists(directory_path):
    print("Directory Not Found...")
    sys.exit()
# Caption every file found in the target directory.
images_in_dir = os.listdir(directory_path)
print("Starting...")
for image_path in images_in_dir:
    # Open the file, run the COCO validation transform, and prepend a
    # batch dimension so the tensor is (1, C, H, W) as the model expects.
    raw_image = Image.open(os.path.join(directory_path, image_path))
    image = coco.val_transform(raw_image).unsqueeze(0)
def create_caption_and_mask(start_token, max_length):
    """Build the (1, max_length) decoder inputs for greedy captioning.

    The caption buffer is zero-filled except position 0, which holds the
    start token; the padding mask is True (hidden) everywhere except
    position 0, so only the start token is visible on the first step.

    Returns:
        (caption, mask): a ``torch.long`` id buffer and a ``torch.bool``
        padding mask, both of shape (1, max_length).
    """
    caption = torch.full((1, max_length), 0, dtype=torch.long)
    caption[0, 0] = start_token
    mask = torch.full((1, max_length), True, dtype=torch.bool)
    mask[0, 0] = False
    return caption, mask
caption, cap_mask = create_caption_and_mask(
    start_token, config.max_position_embeddings)

@torch.no_grad()
def evaluate():
    """Greedily decode one token per step until [SEP] or max length.

    Mutates ``caption`` / ``cap_mask`` in place, unmasking one position
    per step, and returns the caption tensor of token ids.
    """
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)
        predictions = predictions[:, i, :]
        predicted_id = torch.argmax(predictions, axis=-1)
        # Stop on the tokenizer-derived [SEP] id instead of the magic
        # literal 102, so the already-computed end_token is honored and a
        # different vocabulary would still terminate correctly.
        if predicted_id[0] == end_token:
            return caption
        caption[:, i + 1] = predicted_id[0]
        cap_mask[:, i + 1] = False
    return caption

output = evaluate()
result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
print(image_path, ":", result.capitalize())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment