Skip to content

Instantly share code, notes, and snippets.

@enric1994
Created April 19, 2020 22:12
Show Gist options
  • Save enric1994/bfff235e82e3741ca66b4a441b8a0380 to your computer and use it in GitHub Desktop.
Save enric1994/bfff235e82e3741ca66b4a441b8a0380 to your computer and use it in GitHub Desktop.
PyTorch DataLoader boilerplate
from __future__ import print_function, division
import os
import torch
from skimage import io
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import json
from sklearn import preprocessing
from PIL import Image
class CountingDataset(Dataset):
    """Dataset pairing each scene image with the number of objects in its label file.

    Every regular file in ``input_path`` is parsed as JSON. The scene name is
    read from ``data['global']['scene_name']`` and the label is
    ``len(data['objects'])``. The matching image path is built as
    ``<output_path>/<scene_name>/original/00000000.png``.

    Args:
        input_path: Directory containing one JSON label file per scene.
        output_path: Root directory under which the per-scene image folders live.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.
    """

    def __init__(self, input_path, output_path, transform=None):
        samples = []
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the index -> sample mapping (and any split built on it) is
        # reproducible across machines and runs.
        for label_file in sorted(os.listdir(input_path)):
            label_path = os.path.join(input_path, label_file)
            # Sub-directories (or other non-regular entries) cannot be opened
            # and parsed as JSON label files, so skip them.
            if not os.path.isfile(label_path):
                continue
            with open(label_path, encoding='utf-8') as f:
                data = json.load(f)
            scene_name = data['global']['scene_name']
            image_path = os.path.join(output_path, scene_name, 'original', '00000000.png')
            samples.append((image_path, len(data['objects'])))
        # List of (image_path, object_count) tuples, indexed by __getitem__.
        self.dataset = samples
        self.transform = transform

    def __len__(self):
        """Return the number of label files that were loaded."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, count)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one is set; ``count`` is the object count read from the label file.
        """
        image_path, label_id = self.dataset[idx]
        # Open the image in a context manager so its file handle is closed as
        # soon as convert() has loaded the pixel data, instead of lingering
        # until garbage collection (which can exhaust file descriptors).
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label_id
class DatasetFromSubset(Dataset):
    """Wrap an indexable dataset of ``(input, target)`` pairs and transform its inputs.

    Each item of ``subset`` is unpacked into exactly two values. Only the first
    (the input) is passed through ``transform``; the target is returned as-is.
    NOTE(review): presumably ``subset`` is a ``torch.utils.data.Subset`` from a
    split, so each split can get its own transform — confirm with callers.

    Args:
        subset: Any object supporting ``len()`` and integer indexing that yields
            two-element items.
        transform: Optional callable applied to the input when it is truthy.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return ``(transformed_input, target)`` for position ``index``."""
        sample, target = self.subset[index]
        transformed = self.transform(sample) if self.transform else sample
        return transformed, target
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment