Created
February 12, 2022 22:50
-
-
Save tezansahu/52b2f2105953627d267eba31fc3e51d7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def loadAnswerSpace(path: str = os.path.join("dataset", "answer_space.txt")) -> List[str]:
    """Load the vocabulary of candidate answers, one answer per line.

    Args:
        path: Path to the newline-delimited answer-space file. Defaults to
            ``dataset/answer_space.txt`` for backward compatibility.

    Returns:
        The answers as a list of strings, in file order.
    """
    # Explicit encoding: relying on the platform default breaks on Windows
    # (cp1252) for any non-ASCII answer. splitlines() avoids the trailing
    # empty entry that split("\n") would yield for a newline-terminated file.
    with open(path, encoding="utf-8") as f:
        answer_space = f.read().splitlines()
    return answer_space
def tokenizeQuestion(text_encoder, question, device) -> Dict:
    """Tokenize one question string for the given text encoder.

    Args:
        text_encoder: Hugging Face model name used to resolve the tokenizer
            (e.g. ``"bert-base-uncased"``).
        question: The (already normalized) question text.
        device: Torch device the resulting tensors are moved to.

    Returns:
        Dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
        tensors, each placed on ``device``.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(text_encoder)
    encoded = tokenizer(
        text=[question],
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    # Move every tensor the model will consume onto the target device.
    wanted = ("input_ids", "token_type_ids", "attention_mask")
    return {key: encoded[key].to(device) for key in wanted}
def featurizeImage(image_encoder, img_path, device) -> Dict:
    """Turn one image file into model-ready pixel values.

    Args:
        image_encoder: Hugging Face model name used to resolve the feature
            extractor (e.g. ``"google/vit-base-patch16-224-in21k"``).
        img_path: Path to the image file on disk.
        device: Torch device the resulting tensor is moved to.

    Returns:
        Dict with a single ``pixel_values`` tensor placed on ``device``.
    """
    featurizer = transformers.AutoFeatureExtractor.from_pretrained(image_encoder)
    # Force RGB: the extractor expects 3 channels even for grayscale/RGBA files.
    image = Image.open(img_path).convert('RGB')
    batch = featurizer(images=[image], return_tensors="pt")
    return {"pixel_values": batch["pixel_values"].to(device)}
# --- Inference driver: answer one question about one image ---
# NOTE(review): `device` and `MultimodalVQAModel` are referenced here but not
# defined in this snippet — they must come from earlier in the full script.
question = "What is present on the hanger?"
img_path = "dataset/images/image100.png"

# Load the vocabulary of all answers
answer_space = loadAnswerSpace()

# Normalize the question before tokenizing: lowercase, drop the question
# mark (if present) and trim surrounding whitespace.
question = question.lower().replace("?", "").strip()

# Tokenize the question & featurize the image (tensors land on `device`)
tokenized_question = tokenizeQuestion("bert-base-uncased", question, device)
featurized_img = featurizeImage("google/vit-base-patch16-224-in21k", img_path, device)

# Load the model checkpoint (for 5 epochs, final checkpoint should be checkpoint-1500)
model = MultimodalVQAModel(
    pretrained_text_name="bert-base-uncased",
    pretrained_image_name="google/vit-base-patch16-224-in21k",
    num_labels=len(answer_space),
    intermediate_dims=512,
)
checkpoint = os.path.join("checkpoint", "checkpoint-1500", "pytorch_model.bin")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# without it torch.load raises when CUDA is unavailable.
model.load_state_dict(torch.load(checkpoint, map_location=device))
model.to(device)
model.eval()

# Obtain the prediction from the model. The helper functions already moved
# every tensor to `device`, so no extra .to(device) calls are needed here.
# no_grad() skips autograd bookkeeping during pure inference.
with torch.no_grad():
    output = model(
        tokenized_question["input_ids"],
        featurized_img["pixel_values"],
        tokenized_question["attention_mask"],
        tokenized_question["token_type_ids"],
    )

# Map the highest-scoring logit back into the answer space
preds = output["logits"].argmax(axis=-1).cpu().numpy()
answer = answer_space[preds[0]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment