import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms

# data augmentation
# training transforms: random flip, rotation and resized crop for augmentation,
# followed by tensor conversion and normalization with the ImageNet mean/std
train_transforms = transforms.Compose([
    # transforms.ToPILImage(),  # uncomment if ImageDataset returns arrays rather than PIL images
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])
# test transforms: deterministic resize and center crop (no random augmentation at evaluation time),
# with the same tensor conversion and normalization
test_transforms = transforms.Compose([
    # transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])
# load the training and test datasets and apply the transforms
# (ImageDataset is a custom Dataset class assumed to be defined elsewhere in this project)
train_data = ImageDataset(csv_file=r'gdrive/My Drive/Malaria/train.csv', transform=train_transforms)
test_data = ImageDataset(csv_file=r'gdrive/My Drive/Malaria/test.csv', transform=test_transforms)
# obtain the shuffled indices that will be split off for validation
val_size = 0.1  # fraction of the training data held out for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(val_size * num_train))
train_idx, val_idx = indices[split:], indices[:split]

# define samplers that draw the training and validation batches from their respective indices
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
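# illustrative count only (not taken from the actual CSV): with 2,000 training images and
# val_size = 0.1, split = 200, so 1,800 indices feed train_sampler and 200 feed val_sampler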
# prepare the DataLoaders
batch_size = 20   # samples per batch to load
num_workers = 0   # number of subprocesses to use for data loading
train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
val_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, sampler=val_sampler)
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)
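
# quick sanity check (a minimal sketch, assuming ImageDataset yields (image, label) pairs):
# pull one batch and confirm the augmented images come out as [batch_size, 3, 224, 224] tensors
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # expected: torch.Size([20, 3, 224, 224]) and torch.Size([20])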