Skip to content

Instantly share code, notes, and snippets.

@shuternay
Created March 13, 2019 21:10
Show Gist options
  • Save shuternay/7c3d16397913d9c103bea47f575cd475 to your computer and use it in GitHub Desktop.
Save shuternay/7c3d16397913d9c103bea47f575cd475 to your computer and use it in GitHub Desktop.
Custom dataset
import os
from PIL import Image
class CachedPilLoader:
    """Callable that opens image files as RGB PIL images and memoises them by path.

    The first call for a given path reads and decodes the file, converts it to
    RGB and stores the result; later calls with the same path return that very
    same stored object without reading the file again. Entries are never
    evicted, so the cache grows with every distinct path requested.
    """

    def __init__(self):
        # path -> already-converted RGB image
        self.cache = {}

    def __call__(self, path):
        try:
            # Fast path: image decoded on an earlier call. Note the cached
            # object itself is returned, so in-place edits by a caller would
            # be visible to later calls.
            return self.cache[path]
        except KeyError:
            pass
        with open(path, 'rb') as handle:
            # convert() runs while the file is still open, so the pixel data
            # is fully read before the handle is closed.
            rgb_image = Image.open(handle).convert('RGB')
        self.cache[path] = rgb_image
        return rgb_image
# Single module-level loader instance: every dataset handed this object shares
# one in-memory cache of decoded RGB images (it is passed to ValDataset below).
# NOTE(review): if data is loaded from several worker processes, each process
# would presumably hold its own copy of this cache — confirm if memory matters.
loader = CachedPilLoader()
def make_dataset(dir, extensions=None):
    """Return the sorted paths of every file under ``<dir>/images``.

    Args:
        dir: Split directory; a leading ``~`` is expanded. (The name shadows
            the ``dir`` builtin but is kept so keyword callers keep working.)
        extensions: Optional filename-suffix filter: a single string or an
            iterable of strings such as ``('.jpeg',)``, compared
            case-insensitively. ``None`` (the default) keeps every file.

    Returns:
        A list of file paths ordered by sub-directory and then by filename,
        so the index of each sample does not depend on filesystem listing
        order.

    Raises:
        FileNotFoundError: If ``<dir>/images`` is not an existing directory.
            ``os.walk`` ignores such errors by default, which would otherwise
            produce an empty dataset without any warning.
    """
    base = os.path.expanduser(dir)
    images_dir = os.path.join(base, 'images')
    if not os.path.isdir(images_dir):
        raise FileNotFoundError(
            'image directory not found: {!r}'.format(images_dir))

    if extensions is None:
        suffixes = None
    else:
        if isinstance(extensions, str):
            # A bare '.jpeg' must not be iterated character by character.
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    # Sorting the walk and each file list makes the order deterministic.
    for root, _, fnames in sorted(os.walk(images_dir)):
        for fname in sorted(fnames):
            if suffixes is None or fname.lower().endswith(suffixes):
                images.append(os.path.join(root, fname))
    return images
class ValDataset:
    """Validation split whose labels come from an annotation file, not folders.

    Expects ``<root>/images`` (scanned with ``make_dataset``) and
    ``<root>/val_annotations.txt``. In the annotations file, each non-blank line
    starts with an image filename followed by its class name, separated by
    whitespace; any further columns are ignored.

    Args:
        root: Split directory; a leading ``~`` is expanded.
        loader: Callable mapping a file path to a sample.
        transform: Optional callable applied to each loaded sample.
        classes: Optional list of class names (keyword-only).
        class_to_idx: Optional mapping of class name to integer label
            (keyword-only). If only one of ``classes`` / ``class_to_idx`` is
            given, the other is derived from it. If neither is given, both are
            read from the module-level ``train_dataset.dataset`` (the original
            behaviour), so the training set's label mapping is reused.

    Raises:
        KeyError: In ``__init__`` if an annotation names a class missing from
            ``class_to_idx``; in ``__getitem__`` if an image has no annotation.
        ValueError: If a non-blank annotation line has fewer than two fields.
    """

    def __init__(self, root, loader, transform=None, *, classes=None, class_to_idx=None):
        if classes is None and class_to_idx is None:
            # Original behaviour: borrow the label mapping from
            # ``train_dataset`` (defined outside this snippet).
            reference = train_dataset.dataset
            classes, class_to_idx = reference.classes, reference.class_to_idx
        elif class_to_idx is None:
            class_to_idx = {name: idx for idx, name in enumerate(classes)}
        elif classes is None:
            classes = sorted(class_to_idx, key=class_to_idx.__getitem__)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.loader = loader
        self.transform = transform
        # Expand '~' here as well, so the annotations file resolves to the
        # same directory that make_dataset scans.
        root = os.path.expanduser(root)
        self.samples = make_dataset(root)
        # Lower-cased filename -> integer label; lower-casing makes the
        # lookup in __getitem__ case-insensitive.
        self.targets = {}
        with open(os.path.join(root, 'val_annotations.txt')) as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                name, cls, *_ = fields
                self.targets[name.lower()] = self.class_to_idx[cls]

    def __getitem__(self, index):
        """Return ``(sample, target)`` for the image at position ``index``."""
        path = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        # os.path.basename honours the platform separator; splitting on '/'
        # alone keeps the whole path for backslash-joined (Windows) paths.
        target = self.targets[os.path.basename(path).lower()]
        return sample, target

    def __len__(self):
        """Return the number of image files found under ``<root>/images``."""
        return len(self.samples)
# Build the validation split from the local 'tiny-imagenet-200/val' directory,
# reusing the shared caching ``loader``.
# NOTE(review): ``transform_test`` (and ``train_dataset``, which ValDataset
# reads for its label mapping) are not defined in this snippet — presumably they
# come from earlier notebook cells; confirm before running this standalone.
val_dataset = ValDataset('tiny-imagenet-200/val', loader=loader, transform=transform_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment