Created
October 13, 2021 10:12
-
-
Save CLozy/f3e20334880915fc9281fdfdbb4bfeaa to your computer and use it in GitHub Desktop.
dataaugmentation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Data augmentation and loading pipeline for the Malaria image dataset.
# NOTE(review): relies on `transforms`, `np`, `ImageDataset`,
# `SubsetRandomSampler` and `DataLoader` being imported/defined earlier in
# the file -- presumably torchvision.transforms, numpy and torch.utils.data.

# Per-channel mean/std passed to Normalize; these values match the commonly
# used ImageNet statistics.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random flip / rotation / crop act as augmentation.
# NOTE(review): ToPILImage() is intentionally left out (it was commented out
# originally); presumably ImageDataset already yields PIL images -- confirm.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: deterministic resize + centre crop. The previous
# RandomResizedCrop made test-time preprocessing random, so evaluation
# metrics changed from run to run.
test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Load the training and test sets and pass the transforms.
train_data = ImageDataset(csv_file=r'gdrive/My Drive/Malaria/train.csv', transform=train_transforms)
test_data = ImageDataset(csv_file=r'gdrive/My Drive/Malaria/test.csv', transform=test_transforms)

# Second view of the training CSV that uses the evaluation transforms, so
# validation samples are not randomly augmented.
# NOTE(review): assumes ImageDataset indexes rows in CSV order, so the split
# indices below address the same samples in both views -- confirm.
val_data = ImageDataset(csv_file=r'gdrive/My Drive/Malaria/train.csv', transform=test_transforms)

# Obtain the indices that will be used for validation.
val_size = 0.1  # fraction of the training set held out for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)  # in-place shuffle so the hold-out is a random subset
split = int(np.floor(val_size * num_train))
train_idx, val_idx = indices[split:], indices[:split]

# Define samplers for obtaining training and validation batches; the two
# index lists are disjoint, so no sample appears in both loaders.
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

# Prepare DataLoaders.
batch_size = 20  # samples per batch to load
num_workers = 0  # no. of subprocesses to use for data loading
train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
val_loader = DataLoader(val_data, batch_size=batch_size, num_workers=num_workers, sampler=val_sampler)
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment