Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Created February 12, 2022 22:50
Show Gist options
  • Save tezansahu/52b2f2105953627d267eba31fc3e51d7 to your computer and use it in GitHub Desktop.
def loadAnswerSpace(path: str = os.path.join("dataset", "answer_space.txt")) -> List[str]:
    """Load the vocabulary of all candidate answers, one per line.

    Args:
        path: Text file with one candidate answer per line. Defaults to the
            original hard-coded dataset location, so existing callers that
            pass no argument are unaffected.

    Returns:
        The list of answer strings, with trailing newlines stripped.
    """
    with open(path) as f:
        return f.read().splitlines()
def tokenizeQuestion(text_encoder, question, device) -> Dict:
    """Tokenize a single question with the given pretrained tokenizer.

    Wraps the question in a one-element batch, encodes it (padded, truncated
    to 24 tokens, PyTorch tensors) and moves every tensor onto *device*.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(text_encoder)
    encoding = tokenizer(
        text=[question],
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    wanted = ("input_ids", "token_type_ids", "attention_mask")
    return {key: encoding[key].to(device) for key in wanted}
def featurizeImage(image_encoder, img_path, device) -> Dict:
    """Extract pixel features for one image with a pretrained feature extractor.

    Opens the image, forces RGB, runs the extractor as a one-element batch and
    returns the resulting pixel-value tensor moved onto *device*.
    """
    extractor = transformers.AutoFeatureExtractor.from_pretrained(image_encoder)
    image = Image.open(img_path).convert('RGB')
    batch = extractor(images=[image], return_tensors="pt")
    return {"pixel_values": batch["pixel_values"].to(device)}
# --- Inference: predict an answer for one (image, question) pair ---
# NOTE(review): this script relies on `os`, `torch`, `device`, and
# `MultimodalVQAModel` being defined/imported earlier in the file — confirm
# before running it standalone.
question = "What is present on the hanger?"
img_path = "dataset/images/image100.png"
# Load the vocabulary of all answers (label index -> answer string)
answer_space = loadAnswerSpace()
# Tokenize the question & featurize the image
question = question.lower().replace("?", "").strip()  # Remove the question mark (if present) & extra spaces before tokenizing
tokenized_question = tokenizeQuestion("bert-base-uncased", question, device)
featurized_img = featurizeImage("google/vit-base-patch16-224-in21k", img_path, device)
# Load the model checkpoint (for 5 epochs, final checkpoint should be checkpoint-1500)
model = MultimodalVQAModel(
    pretrained_text_name="bert-base-uncased",
    pretrained_image_name="google/vit-base-patch16-224-in21k",
    num_labels=len(answer_space),
    intermediate_dims=512
)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint = os.path.join("checkpoint", "checkpoint-1500", "pytorch_model.bin")
model.load_state_dict(torch.load(checkpoint))
model.to(device)
model.eval()  # switch to evaluation mode for inference
# Obtain the prediction from the model
# (the .to(device) calls below are redundant — tokenizeQuestion/featurizeImage
# already moved these tensors to `device` — but harmless)
# NOTE(review): consider wrapping the forward pass in torch.no_grad() to skip
# gradient bookkeeping.
input_ids = tokenized_question["input_ids"].to(device)
token_type_ids = tokenized_question["token_type_ids"].to(device)
attention_mask = tokenized_question["attention_mask"].to(device)
pixel_values = featurized_img["pixel_values"].to(device)
output = model(input_ids, pixel_values, attention_mask, token_type_ids)
# Obtain the answer from the answer space
preds = output["logits"].argmax(axis=-1).cpu().numpy()  # highest-scoring label index per batch item
answer = answer_space[preds[0]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment