Skip to content

Instantly share code, notes, and snippets.

@francois-rozet
Created August 25, 2023 14:54
Show Gist options
  • Save francois-rozet/aa7aabc2f76774dabfce998b096d92be to your computer and use it in GitHub Desktop.
Save francois-rozet/aa7aabc2f76774dabfce998b096d92be to your computer and use it in GitHub Desktop.
Tiny ImageNet dataset
import numpy as np
import re
import torch
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from typing import Callable
class TinyImageNet(torch.utils.data.Dataset):
    r"""Tiny ImageNet dataset, decoded once and kept entirely in memory.

    Download:
        http://cs231n.stanford.edu/tiny-imagenet-200.zip

    Arguments:
        root: Directory of the extracted archive. It must contain `wnids.txt`,
            `words.txt` and the sub-directory of the requested split.
        split: Which subset to load, one of `'train'`, `'val'` or `'test'`.
        transform: Optional callable applied to the PIL image of each item.

    Attributes:
        classes: Mapping from class tag (a line of `wnids.txt`) to class index.
        descriptions: Mapping from class index to the description found in
            `words.txt` for that class tag.
        images: All images of the split, converted to RGB and stacked into a
            single array.
        labels: Array of class indices aligned with `images`, or `None` for the
            `'test'` split, which has no annotations.

    Raises:
        ValueError: If `split` is not a known split, or if no image is found.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(
        self,
        root: str,
        split: str = 'train',
        transform: Callable = None,
    ):
        super().__init__()

        # Fail early, before any disk access, with a message naming the culprit.
        if split not in self._SPLITS:
            raise ValueError(f"unknown split '{split}', expected one of {self._SPLITS}")

        root = Path(root)

        # Classes
        with open(root / 'wnids.txt', encoding='utf-8') as f:
            classes = {tag: i for i, tag in enumerate(f.read().splitlines())}

        # Parse line by line so that a malformed line cannot swallow the next one.
        descriptions = {}

        with open(root / 'words.txt', encoding='utf-8') as f:
            for line in f:
                fields = line.split(maxsplit=1)

                if len(fields) == 2 and fields[0] in classes:
                    descriptions[classes[fields[0]]] = fields[1].strip()

        self.classes = classes
        self.descriptions = descriptions

        # Files
        # Directory listings are sorted because their native order depends on
        # the file system, which would make item indices irreproducible.
        files = []
        labels = []

        if split == 'train':
            for subdir in sorted((root / 'train').iterdir()):
                if subdir.is_dir():
                    for file in sorted((subdir / 'images').glob('*.JPEG')):
                        files.append(file)
                        labels.append(classes[subdir.name])
        elif split == 'val':
            with open(root / 'val' / 'val_annotations.txt', encoding='utf-8') as f:
                for match in re.finditer(r'(\w+\.JPEG)\s+(\w+)', f.read()):
                    name, tag = match.groups()
                    files.append(root / 'val' / 'images' / name)
                    labels.append(classes[tag])
        else:  # split == 'test', images only
            files.extend(sorted((root / 'test' / 'images').glob('*.JPEG')))

        if not files:
            raise ValueError(f"no image found for split '{split}' in '{root}'")

        # Load
        with ThreadPoolExecutor() as executor:
            arrays = list(tqdm(executor.map(self._load_image, files), total=len(files)))

        self.images = np.stack(arrays)
        self.labels = np.array(labels) if labels else None

        # Transform
        self.transform = transform

    @staticmethod
    def _load_image(file) -> np.ndarray:
        r"""Decodes an image file into an RGB array and closes the file."""

        with Image.open(file) as img:
            return np.asarray(img.convert('RGB'))

    def __len__(self) -> int:
        r"""Returns the number of images in the split."""

        return len(self.images)

    def __getitem__(self, i: int):
        r"""Returns the `i`-th `(image, label)` pair.

        The image is a PIL image, passed through `transform` if one was given.
        The label is `None` when the split has no labels.
        """

        img = Image.fromarray(self.images[i])

        if self.transform is not None:
            img = self.transform(img)

        label = None if self.labels is None else self.labels[i]

        return img, label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment