-
-
Save ehofesmann/a1e73d02941463b554c7f198ecf4488a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
from torch.utils.data import DataLoader | |
import os | |
# A torchvision COCO dataset backed by a FiftyOne collection that yields
# DETR-style (pixel_values, labels) pairs.
class FiftyOneCocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset built from a FiftyOne sample collection.

    On construction, ``folder`` is checked for ``labels.json`` and a ``data``
    directory. If either is missing, ``samples`` is exported there in COCO
    detection format with media symlinked rather than copied. An existing
    export is reused as-is, so a later change to ``samples`` or
    ``label_field`` has no effect until that folder is removed.

    Args:
        samples: FiftyOne collection (dataset or view) to export; must provide
            ``export`` and ``distinct``.
        folder: directory that holds (or will receive) the COCO export.
        processor: callable invoked as
            ``processor(images=..., annotations=..., return_tensors="pt")``;
            presumably a Hugging Face DETR image processor — confirm.
        label_field: detections field exported as the COCO annotations.
    """

    def __init__(self, samples, folder, processor, label_field="ground_truth"):
        ann_file = os.path.join(folder, "labels.json")
        img_folder = os.path.join(folder, "data")
        if not (os.path.exists(ann_file) and os.path.exists(img_folder)):
            print(f"Existing COCO annotation not found, exporting to {folder}")
            # NOTE(review): `fo` (fiftyone) is not imported in this file. It must
            # already be bound in this module's globals (e.g. a notebook where
            # `import fiftyone as fo` ran earlier), otherwise this raises NameError.
            samples.export(
                folder,
                dataset_type=fo.types.COCODetectionDataset,
                export_media="symlink",
                label_field=label_field,
                # Category list = the distinct detection labels present in `samples`.
                classes=samples.distinct(label_field + ".detections.label"),
            )
        super().__init__(img_folder, ann_file)
        self.processor = processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, labels)`` for the image at position *idx*.

        The parent class yields a PIL image and its list of COCO annotations.
        These are wrapped as ``{"image_id": ..., "annotations": [...]}`` and
        passed to the processor, which converts the target to DETR format and
        resizes/normalizes both image and target.
        """
        # Data augmentation on the PIL image / target could be inserted here,
        # before the processor call.
        img, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # Drop only the leading batch axis. `.squeeze()` would also remove any
        # other size-1 dimension (e.g. a 1-pixel-high image).
        pixel_values = encoding["pixel_values"][0]
        labels = encoding["labels"][0]
        return pixel_values, labels
def collate_fn_generator(processor):
    """Build a DataLoader ``collate_fn`` that pads images with *processor*.

    The returned function takes a list of ``(pixel_values, labels)`` items. It
    pads the images in one ``processor.pad(..., return_tensors="pt")`` call and
    returns a dict with ``pixel_values`` and ``pixel_mask`` taken from the
    padding result, plus ``labels`` as the list of per-item targets in batch
    order.
    """

    def collate_fn(batch):
        padded = processor.pad([sample[0] for sample in batch], return_tensors="pt")
        return {
            "pixel_values": padded["pixel_values"],
            "pixel_mask": padded["pixel_mask"],
            "labels": [sample[1] for sample in batch],
        }

    return collate_fn
def create_data_loaders(
    train_view,
    val_view,
    processor,
    train_dir="train",
    val_dir="val",
    *,
    label_field="ground_truth",
    train_batch_size=4,
    val_batch_size=2,
):
    """Create training and validation DataLoaders from two FiftyOne views.

    Each view is wrapped in ``FiftyOneCocoDetection``, which exports it in COCO
    format to its directory if no export exists there yet. Both loaders share
    one collate function built from ``processor``.

    Args:
        train_view: FiftyOne collection used for training.
        val_view: FiftyOne collection used for validation.
        processor: image processor used for per-sample encoding and batch padding.
        train_dir: export directory for the training split.
        val_dir: export directory for the validation split.
        label_field: detections field to export (keyword-only).
        train_batch_size: batch size of the training loader (keyword-only).
        val_batch_size: batch size of the validation loader (keyword-only).

    Returns:
        ``(train_dataloader, val_dataloader)``. Only the training loader
        shuffles.
    """
    train_dataset = FiftyOneCocoDetection(train_view, train_dir, processor, label_field=label_field)
    val_dataset = FiftyOneCocoDetection(val_view, val_dir, processor, label_field=label_field)
    collate_fn = collate_fn_generator(processor)
    train_dataloader = DataLoader(
        train_dataset, collate_fn=collate_fn, batch_size=train_batch_size, shuffle=True
    )
    val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=val_batch_size)
    return train_dataloader, val_dataloader
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment