Skip to content

Instantly share code, notes, and snippets.

Last active April 24, 2023 16:19
Show Gist options
  • Save ehofesmann/a1e73d02941463b554c7f198ecf4488a to your computer and use it in GitHub Desktop.
Save ehofesmann/a1e73d02941463b554c7f198ecf4488a to your computer and use it in GitHub Desktop.
import os

import torchvision
from torch.utils.data import DataLoader
# A class to process a FiftyOne dataset in DETR format
class FiftyOneCocoDetection(torchvision.datasets.CocoDetection):
    """Expose a FiftyOne sample collection as a COCO-backed detection dataset.

    On construction the collection is exported to ``folder`` in COCO detection
    format (``labels.json`` plus a ``data/`` media directory), unless both of
    those paths already exist. The export is then read back through
    ``torchvision.datasets.CocoDetection``. Each item is passed through
    ``processor`` before being returned.

    NOTE(review): ``fo`` (the ``fiftyone`` package) is referenced in
    ``__init__`` but is not imported in this file. It must be in module scope
    for the export branch to run.
    """

    def __init__(self, samples, folder, processor, label_field="ground_truth"):
        """Export ``samples`` if needed and load the COCO export.

        Args:
            samples: collection to export. It must provide ``export`` and
                ``distinct`` (presumably a FiftyOne dataset/view — confirm).
            folder: directory that holds, or will receive, the COCO export.
            processor: callable invoked per item as
                ``processor(images=..., annotations=..., return_tensors="pt")``.
            label_field: name of the detections field to export.
        """
        ann_file = os.path.join(folder, "labels.json")
        img_folder = os.path.join(folder, "data")
        # Re-export when either half of a previous export is missing.
        if not (os.path.exists(ann_file) and os.path.exists(img_folder)):
            print(f"Existing COCO annotation not found, exporting to {folder}")
            samples.export(
                folder,
                dataset_type=fo.types.COCODetectionDataset,
                export_media="symlink",
                label_field=label_field,
                # Class list = the distinct detection labels in this collection.
                classes=samples.distinct(label_field + ".detections.label"),
            )
        super().__init__(img_folder, ann_file)
        self.processor = processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the ``idx``-th image.

        ``pixel_values`` is the processor's ``pixel_values`` output with its
        leading batch axis removed. ``target`` is the first (only) entry of
        the processor's ``labels`` output.
        """
        # Read in the PIL image and its target in COCO format.
        # Feel free to add data augmentation here, before the processor runs.
        img, annotations = super().__getitem__(idx)
        # Pair the COCO annotation list with its image id for the processor.
        target = {"image_id": self.ids[idx], "annotations": annotations}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # Index the batch axis explicitly: .squeeze() would also drop any other
        # size-1 axis (e.g. a single-channel image), changing the tensor rank.
        pixel_values = encoding["pixel_values"][0]
        target = encoding["labels"][0]
        return pixel_values, target
def collate_fn_generator(processor):
    """Create a ``DataLoader`` collate function that pads images with ``processor``.

    The returned callable receives a list of ``(pixel_values, target)`` items.
    It pads the images in one call to
    ``processor.pad(images, return_tensors="pt")`` and returns a dict with:

    - ``"pixel_values"`` and ``"pixel_mask"``, taken from the pad result;
    - ``"labels"``, the targets as a plain list in batch order.
    """

    def collate_fn(batch):
        padded = processor.pad([sample[0] for sample in batch], return_tensors="pt")
        targets = [sample[1] for sample in batch]
        return {
            "pixel_values": padded["pixel_values"],
            "pixel_mask": padded["pixel_mask"],
            "labels": targets,
        }

    return collate_fn
def create_data_loaders(
    train_view,
    val_view,
    processor,
    train_dir="train",
    val_dir="val",
    *,
    train_batch_size=4,
    val_batch_size=2,
    label_field="ground_truth",
):
    """Build the training and validation ``DataLoader``s from two FiftyOne views.

    Each view is wrapped in ``FiftyOneCocoDetection``, which exports it to
    ``train_dir`` / ``val_dir`` when no COCO export is found there. Batches are
    assembled by the padding collate function from
    ``collate_fn_generator(processor)``.

    Args:
        train_view: samples for the training split.
        val_view: samples for the validation split.
        processor: image/annotation processor, shared by both datasets and
            the collate function.
        train_dir: export/cache directory for the training split.
        val_dir: export/cache directory for the validation split.
        train_batch_size: batch size of the training loader (default 4).
        val_batch_size: batch size of the validation loader (default 2).
        label_field: detections field to export from each view
            (default ``"ground_truth"``).

    Returns:
        ``(train_dataloader, val_dataloader)``. Only the training loader
        shuffles.
    """
    train_dataset = FiftyOneCocoDetection(
        train_view, train_dir, processor, label_field=label_field
    )
    val_dataset = FiftyOneCocoDetection(
        val_view, val_dir, processor, label_field=label_field
    )
    collate_fn = collate_fn_generator(processor)
    train_dataloader = DataLoader(
        train_dataset, collate_fn=collate_fn, batch_size=train_batch_size, shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset, collate_fn=collate_fn, batch_size=val_batch_size
    )
    return train_dataloader, val_dataloader
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment