Skip to content

Instantly share code, notes, and snippets.

@erykml
Created January 25, 2019 21:46
Show Gist options
  • Save erykml/0ddea420d29eaf255beadd2ad3545bd3 to your computer and use it in GitHub Desktop.
Save erykml/0ddea420d29eaf255beadd2ad3545bd3 to your computer and use it in GitHub Desktop.
# 1. defining parameters ----
# Module-level settings read by the train/validation split and by the
# DataLoader calls further down in this script.

# number of subprocesses to use for data loading (0 = load in the main process)
num_workers = 0
# number of samples to load per batch
batch_size = 32
# fraction of the training set held out for validation (0.2 == 20%)
valid_size = 0.2
# whether the DataLoaders place batches in pinned (page-locked) host memory.
# Every DataLoader call in this script passes `pin_memory=pin_memory`, so the
# name has to exist. False is DataLoader's own default; set it to True when
# training on a CUDA GPU.
pin_memory = False
# define transformations that will be applied to images
# Pipeline, applied in order to every image of both datasets:
#   1. resize to 128x128
#   2. convert to a tensor
#   3. normalise each of the three channels with mean 0.5 and std 0.5
image_transforms = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# 2. define the datasets ----
# Both image folders are read through ImageFolder and share the same
# preprocessing pipeline (`image_transforms`).
train_data = datasets.ImageFolder(
    root='data/training_set/',
    transform=image_transforms,
)
test_data = datasets.ImageFolder(
    root='data/test_set/',
    transform=image_transforms,
)
# 3. obtain indices that will be used for validation ----
# Shuffle every index of the training set in place, then cut the shuffled
# list once: the first `split` entries become the validation subset and the
# remainder is used for training.
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)

# number of validation samples, rounded down
split = int(np.floor(valid_size * num_train))
valid_idx = indices[:split]
train_idx = indices[split:]

# define samplers for obtaining training and validation batches
# Each sampler draws only from its own index subset.
valid_sampler = SubsetRandomSampler(valid_idx)
train_sampler = SubsetRandomSampler(train_idx)
# 4. prepare data loaders (combine dataset with sampler) ----
# Options shared by all three loaders. `batch_size`, `num_workers` and
# `pin_memory` must already be defined before this point.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Training and validation batches both come from `train_data`; the two
# samplers restrict each loader to its own subset of indices.
train_loader = torch.utils.data.DataLoader(
    train_data, sampler=train_sampler, **_loader_options
)
valid_loader = torch.utils.data.DataLoader(
    train_data, sampler=valid_sampler, **_loader_options
)

# The test set is read in full, in shuffled order.
test_loader = torch.utils.data.DataLoader(
    test_data, shuffle=True, **_loader_options
)
# specify the image classes ----
# Human-readable label names.
# NOTE(review): presumably the position in this list matches the class index
# ImageFolder assigns to each sub-folder -- verify against the data folders.
classes = ["mario", "wario"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment