Created
March 13, 2019 21:10
-
-
Save shuternay/7c3d16397913d9c103bea47f575cd475 to your computer and use it in GitHub Desktop.
Custom dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from PIL import Image | |
class CachedPilLoader:
    """Callable that opens image files as RGB PIL images and memoizes them.

    Each decoded image is stored in ``self.cache`` keyed by the path it was
    loaded from and is never evicted, so memory grows with the number of
    distinct paths requested.
    """

    def __init__(self):
        # file path -> RGB ``PIL.Image`` produced by an earlier call
        self.cache = {}

    def __call__(self, path):
        """Return the RGB image stored at *path*, decoding it only once.

        Later calls with an equal *path* return the very same image object,
        so callers should treat the result as read-only.
        """
        try:
            return self.cache[path]
        except KeyError:
            pass
        with open(path, 'rb') as handle:
            # Image.open is lazy; convert() reads the pixel data while the
            # file handle is still open, so the cached copy is self-contained.
            rgb_image = Image.open(handle).convert('RGB')
        self.cache[path] = rgb_image
        return rgb_image


# One shared instance, so every lookup made through it hits the same cache.
loader = CachedPilLoader()
def make_dataset(dir, extensions=None):
    """Collect the paths of the files under ``<dir>/images``, recursively.

    Args:
        dir: Dataset root directory; a leading ``~`` is expanded. Files are
            gathered from its ``images`` subdirectory and all subfolders.
        extensions: Optional filename suffix (e.g. ``'.jpeg'``) or iterable
            of suffixes to keep. Matching is case-insensitive, so ``'.jpeg'``
            also accepts ``'val_0.JPEG'``. ``None`` (the default) keeps every
            file; an empty iterable keeps none.

    Returns:
        List of file paths, ordered by directory path and then by file name.
        Empty when ``<dir>/images`` does not exist (``os.walk`` yields
        nothing for a missing directory).
    """
    images_dir = os.path.join(os.path.expanduser(dir), 'images')

    suffixes = None
    if extensions is not None:
        if isinstance(extensions, str):
            extensions = (extensions,)
        # str.endswith accepts a tuple; lower-case once for case-insensitivity.
        suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    for root, _, fnames in sorted(os.walk(images_dir)):
        for fname in sorted(fnames):
            if suffixes is None or fname.lower().endswith(suffixes):
                images.append(os.path.join(root, fname))
    return images
class ValDataset:
    """Validation split given as an image folder plus a label annotation file.

    Samples are the files under ``<root>/images`` (collected by
    ``make_dataset``). Labels come from ``<root>/val_annotations.txt``, where
    each non-blank line starts with whitespace-separated
    ``<file name> <class name>`` fields; any further fields are ignored.
    """

    def __init__(self, root, loader, transform=None, class_to_idx=None,
                 classes=None):
        """Index the samples under *root* and parse their labels.

        Args:
            root: Directory containing ``images/`` and ``val_annotations.txt``.
            loader: Callable mapping a file path to a sample.
            transform: Optional callable applied to every loaded sample.
            class_to_idx: Mapping from class name to integer label. When
                omitted, it is taken (together with ``classes``) from a
                module-level ``train_dataset.dataset``, so validation labels
                agree with the training labels.
            classes: Class names. When omitted but ``class_to_idx`` is given,
                they are derived from ``class_to_idx`` ordered by label.

        Raises:
            KeyError: an annotated class name is missing from ``class_to_idx``.
            ValueError: a non-blank annotation line has fewer than two fields.
        """
        if class_to_idx is None:
            # Backward-compatible fallback. NOTE(review): relies on a global
            # ``train_dataset`` defined outside this class, presumably a
            # Subset wrapping an ImageFolder - confirm at the call site.
            reference = train_dataset.dataset
            if classes is None:
                classes = reference.classes
            class_to_idx = reference.class_to_idx
        elif classes is None:
            classes = sorted(class_to_idx, key=class_to_idx.__getitem__)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.loader = loader
        self.transform = transform
        self.samples = make_dataset(root)

        # lower-cased file name -> integer label
        self.targets = {}
        with open(os.path.join(root, 'val_annotations.txt')) as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines instead of failing to unpack
                name, cls, *_ = fields
                self.targets[name.lower()] = self.class_to_idx[cls]

    def __getitem__(self, index):
        """Return ``(sample, target)`` for the file at position *index*.

        Raises:
            KeyError: the file has no entry in the annotation file.
        """
        path = self.samples[index]
        # os.path.basename rather than rsplit('/'): paths are built with the
        # platform separator, which is '\\' on Windows. Look the label up
        # first so a missing annotation fails before the (possibly slow) load.
        target = self.targets[os.path.basename(path).lower()]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        """Return the number of sample files found under ``<root>/images``."""
        return len(self.samples)
# Tiny-ImageNet validation split, read through the shared caching ``loader``.
# NOTE(review): ``transform_test`` is not defined in this snippet; it is
# expected to be provided elsewhere before this line runs.
val_dataset = ValDataset(
    'tiny-imagenet-200/val',
    loader=loader,
    transform=transform_test,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment