Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Create PyTorch datasets and dataset loaders for a subset of CIFAR10 classes.
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Transformations
RC = transforms.RandomCrop(32, padding=4)
RHF = transforms.RandomHorizontalFlip()
RVF = transforms.RandomVerticalFlip()
NRM = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()
# Transforms object for trainset with augmentation.
# TPIL runs first so that the crop/flip steps receive a PIL image
# (presumably the raw samples are H x W x C uint8 arrays -- confirm against the dataset).
transform_with_aug = transforms.Compose([TPIL, RC, RHF, TT, NRM])
# Transforms object for testset with NO augmentation
transform_no_aug = transforms.Compose([TT, NRM])
# Downloading/Loading CIFAR10 data.
# No transform is passed here: the raw images are extracted below and the
# transforms are applied per sample by DatasetMaker.
trainset = CIFAR10(root='./data', train=True, download=True)
testset = CIFAR10(root='./data', train=False, download=True)
classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}
# Separating trainset/testset data/label.
# Current torchvision exposes the samples as `.data` / `.targets`; the
# `train_data` / `test_data` / `train_labels` / `test_labels` attributes only
# exist in old releases, so fall back to them when the new names are missing.
try:
    x_train, y_train = trainset.data, trainset.targets
    x_test, y_test = testset.data, testset.targets
except AttributeError:
    x_train, y_train = trainset.train_data, trainset.train_labels
    x_test, y_test = testset.test_data, testset.test_labels
# Define a function to separate CIFAR classes by class index
def get_class_i(x, y, i):
    """
    Pick out the samples of `x` whose label in `y` equals `i`.

    x: indexable collection of samples (e.g. the dataset's image array)
    y: labels aligned with `x`, one per sample
    i: the class label to keep (0 to 9 for CIFAR10)
    return: list of the matching samples, in their original order
    """
    labels = np.array(y)
    # First-axis positions at which the label matches the requested class
    matching_rows = np.nonzero(labels == i)[0]
    selected = []
    for row in matching_rows:
        selected.append(x[row])
    return selected
class DatasetMaker(Dataset):
    """
    Map-style dataset that presents several per-class sample lists as one
    flat, indexable dataset.

    Samples are laid out bin after bin: indices 0..len(datasets[0])-1 belong
    to the first bin, the next len(datasets[1]) indices to the second, etc.
    """

    def __init__(self, datasets, transformFunc = transform_no_aug, labels = None):
        """
        datasets: a list of get_class_i outputs, i.e. a list of list of images for selected classes
        transformFunc: callable applied to every image before it is returned
        labels: optional sequence with one label per entry of `datasets`.
            When omitted, the label returned for a sample is the position of
            its bin in `datasets` (0, 1, ...), NOT the original class id.
            Pass e.g. [3, 5] to get the original ids back.
        raises: ValueError if `labels` does not have one entry per bin
        """
        self.datasets = datasets
        self.lengths = [len(d) for d in self.datasets]
        self.transformFunc = transformFunc
        if labels is not None and len(labels) != len(self.datasets):
            raise ValueError("labels must contain exactly one entry per dataset")
        self.labels = labels

    def __getitem__(self, i):
        """
        Return (transformed image, label) for the flat index `i`.

        Negative indices count from the end, as for a list.
        raises: IndexError if `i` is outside [-len(self), len(self))
        """
        total = len(self)
        # Without this normalisation a negative index would be resolved
        # inside the first bin only and silently return the wrong sample.
        if i < 0:
            i += total
        if not 0 <= i < total:
            raise IndexError("DatasetMaker index out of range")
        class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
        img = self.datasets[class_label][index_wrt_class]
        img = self.transformFunc(img)
        if self.labels is not None:
            class_label = self.labels[class_label]
        return img, class_label

    def __len__(self):
        """Total number of samples across all bins."""
        return sum(self.lengths)

    def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
        """
        Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.

        bin_sizes: number of elements in each bin, in order
        absolute_index: non-negative index into the concatenation of all bins
        verbose: when True, print the intermediate values
        return: (bin_index, index_wrt_class)
        """
        # Running totals: accum[k] is the first absolute index AFTER bin k
        accum = np.add.accumulate(bin_sizes)
        if verbose:
            print("accum =", accum)
        # Number of running totals <= absolute_index, i.e. how many bins are
        # entirely before the index. Empty bins are skipped automatically.
        bin_index = int(np.searchsorted(accum, absolute_index, side="right"))
        if verbose:
            print("class_label =", bin_index)
        # Which element of the selected class/bin does the index correspond to?
        bin_start = accum[bin_index - 1] if bin_index > 0 else 0
        index_wrt_class = int(absolute_index - bin_start)
        if verbose:
            print("index_wrt_class =", index_wrt_class)
        return bin_index, index_wrt_class
# ================== Usage ================== #
# Cats (class 3 of CIFAR) versus dogs (class 5 of CIFAR), for both splits
cat_dog_trainset = DatasetMaker(
    [get_class_i(x_train, y_train, classDict[name]) for name in ('cat', 'dog')],
    transform_with_aug,
)
cat_dog_testset = DatasetMaker(
    [get_class_i(x_test, y_test, classDict[name]) for name in ('cat', 'dog')],
    transform_no_aug,
)
# Loader options shared by both splits
kwargs = dict(num_workers=2, pin_memory=False)
# Wrap each dataset in a DataLoader; only the training loader shuffles
trainsetLoader = DataLoader(cat_dog_trainset, shuffle=True, batch_size=64, **kwargs)
testsetLoader = DataLoader(cat_dog_testset, shuffle=False, batch_size=64, **kwargs)
@Miladiouss

This comment has been minimized.

Copy link
Owner Author

@Miladiouss Miladiouss commented Sep 11, 2018

To test the performance of a neural net in the early stages, you don't need all 10 CIFAR-10 classes; 2 are enough and are a lot faster to train. But not just any two: pick the hard pairs, such as [cat, dog], [car, truck], and [deer, horse], which are difficult because of their similarities. In many cases, one can tell the difference between airplanes and horses simply by measuring the blueness of an image! So, I wrote this code to build trainset/testset and trainsetLoader/testsetLoader for a subset of CIFAR-10 classes. If you find this helpful, please give it a star.

@JaeDukSeo

This comment has been minimized.

Copy link

@JaeDukSeo JaeDukSeo commented May 18, 2019

thanks for this!

@hyeongminoh

This comment has been minimized.

Copy link

@hyeongminoh hyeongminoh commented May 26, 2019

thank you!
But I got a problem:
AttributeError: 'CIFAR10' object has no attribute 'train_data'
how can i fix it?

@agoel10

This comment has been minimized.

Copy link

@agoel10 agoel10 commented May 31, 2019

Thanks for this! Got a small error using this code. Need to convert the output images from numpy to PIL

@SeoHyeong

This comment has been minimized.

Copy link

@SeoHyeong SeoHyeong commented Jul 11, 2019

@hyeongminoh
You could just change those 4 lines as follows.
x_train = trainset.data
x_test = testset.data
y_train = trainset.targets
y_test = testset.targets

  • img = Image.fromarray(img) this line is needed before the transformation as @agoel10 said.
  • Running your code, I've found out that target values are actually returned wrt ordering of your queries, such as cat and dog above meaning that cat is re-assigned to a target value 0 and dog to 1. How can I keep and return the original target values?
@venushkaT

This comment has been minimized.

Copy link

@venushkaT venushkaT commented Nov 4, 2019

@SeoHyeong how do you print trainsetLoader/ testsetLoader target values. Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.