Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Created February 12, 2022 14:21
Show Gist options
  • Save tezansahu/3dd6924557d7ed141db967fc66fe8ebb to your computer and use it in GitHub Desktop.
Save tezansahu/3dd6924557d7ed141db967fc66fe8ebb to your computer and use it in GitHub Desktop.
# @dataclass generates __init__, __repr__ and __eq__ from the annotated
# fields below, which removes constructor boilerplate.
@dataclass
class MultimodalCollator:
    """Collate raw question/image/label examples into one batch of tensors.

    Questions are tokenized with ``tokenizer``. Images are loaded from disk
    by id and turned into pixel tensors with ``preprocessor``. Labels become
    an int64 tensor. The returned dict merges all three.

    Attributes:
        tokenizer: Tokenizer applied to the list of question strings.
        preprocessor: Image feature extractor applied to the list of loaded
            RGB images.
        max_length: Truncation length passed to the tokenizer.
        image_dir: Directory containing the image files.
        image_extension: File suffix appended to each image id.
    """

    tokenizer: AutoTokenizer
    preprocessor: AutoFeatureExtractor
    # Defaults reproduce the previously hard-coded values.
    max_length: int = 24
    image_dir: str = os.path.join("dataset", "images")
    image_extension: str = ".png"

    def tokenize_text(self, texts: List[str]):
        """Tokenize a batch of questions.

        Sequences are padded to the longest one in the batch and truncated
        to ``self.max_length`` tokens.

        Args:
            texts: Question strings, one per example.

        Returns:
            Dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
            tensors exactly as returned by the tokenizer. The leading batch
            dimension is kept (presumably ``(batch, seq_len)``).
        """
        encoded_text = self.tokenizer(
            text=texts,
            padding='longest',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
            return_token_type_ids=True,
            return_attention_mask=True,
        )
        # No .squeeze(): for a batch of one example it would drop the batch
        # axis and hand the model unbatched tensors.
        return {
            "input_ids": encoded_text['input_ids'],
            "token_type_ids": encoded_text['token_type_ids'],
            "attention_mask": encoded_text['attention_mask'],
        }

    def _load_image(self, image_id: str):
        """Open ``<image_dir>/<image_id><image_extension>`` as an RGB image.

        The file handle is closed before returning. ``convert`` returns a
        separate, already-loaded image, so closing the source is safe.
        """
        path = os.path.join(self.image_dir, image_id + self.image_extension)
        with Image.open(path) as image:
            return image.convert('RGB')

    def preprocess_images(self, images: List[str]):
        """Load images by id and convert them to pixel tensors.

        Args:
            images: Image ids (file names without directory or extension).

        Returns:
            Dict with a ``pixel_values`` tensor as returned by the
            preprocessor, with the leading batch dimension kept.
        """
        processed_images = self.preprocessor(
            images=[self._load_image(image_id) for image_id in images],
            return_tensors="pt",
        )
        # No .squeeze(), for the same batch-of-one reason as tokenize_text.
        return {
            "pixel_values": processed_images['pixel_values'],
        }

    @staticmethod
    def _column(raw_batch, key):
        """Return the values of ``key`` from a dict-of-lists or list-of-dicts batch."""
        if isinstance(raw_batch, dict):
            return raw_batch[key]
        return [example[key] for example in raw_batch]

    def __call__(self, raw_batch_dict):
        """Build a model batch from a dict of columns or a list of example dicts.

        Each example must provide ``question``, ``image_id`` and ``label``.

        Returns:
            Dict merging the text tensors, ``pixel_values`` and an int64
            ``labels`` tensor.
        """
        return {
            **self.tokenize_text(self._column(raw_batch_dict, 'question')),
            **self.preprocess_images(self._column(raw_batch_dict, 'image_id')),
            'labels': torch.tensor(
                self._column(raw_batch_dict, 'label'),
                dtype=torch.int64,
            ),
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment