Created January 25, 2019 21:46
Save erykml/0ddea420d29eaf255beadd2ad3545bd3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. defining parameters ----
# number of subprocesses to use for data loading (0 = load in the main process)
num_workers = 0
# number of samples to load per batch
batch_size = 32
# fraction of the training set to hold out for validation
valid_size = 0.2
# copy batches into page-locked host memory only when a CUDA device is
# available to consume them; all three DataLoaders below read this value,
# which was previously never defined (NameError)
pin_memory = torch.cuda.is_available()
# transformations applied to every image: resize to a fixed 128x128, convert
# to a tensor (values in [0, 1]), then map each of the three channels to
# [-1, 1] via (x - 0.5) / 0.5
image_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 2. define the datasets ----
# Both splits are read with torchvision's ImageFolder (root/<class>/<image>
# layout) and share the exact same preprocessing pipeline; the training split
# is constructed first, then the test split.
train_data, test_data = (
    datasets.ImageFolder(root=folder, transform=image_transforms)
    for folder in ('data/training_set/', 'data/test_set/')
)
# 3. obtain indices that will be used for validation ----
# Seed for the train/validation split. Fixing it keeps the validation subset
# identical across runs, so metrics stay comparable and a resumed run never
# trains on images that were previously held out. A private RandomState is
# used so the global NumPy RNG is left untouched.
split_seed = 42
num_train = len(train_data)
indices = list(range(num_train))
np.random.RandomState(split_seed).shuffle(indices)
# the first floor(valid_size * num_train) shuffled indices become the
# validation subset; the remaining ones are used for training
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches; each one
# yields its own subset's indices in a fresh random order whenever iterated
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
# 4. prepare data loaders (combine dataset with sampler) ----
# settings common to every loader
loader_kwargs = dict(batch_size=batch_size,
                     num_workers=num_workers,
                     pin_memory=pin_memory)

# Training and validation both read from train_data; the disjoint index
# samplers are what keep the two subsets apart.
train_loader = torch.utils.data.DataLoader(train_data, sampler=train_sampler,
                                           **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(train_data, sampler=valid_sampler,
                                           **loader_kwargs)

# The test split is drawn in random order on each pass
# (NOTE(review): presumably so sampled test batches vary — confirm intent).
test_loader = torch.utils.data.DataLoader(test_data, shuffle=True,
                                          **loader_kwargs)
# specify the image classes ----
# human-readable names indexed by integer class label
# (NOTE(review): presumably mirrors the sorted folder order ImageFolder uses,
# i.e. train_data.classes — confirm)
classes = 'mario wario'.split()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.