Skip to content

Instantly share code, notes, and snippets.

@pyaf
Created February 2, 2018 15:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pyaf/9a28637614d4dbe6d2cc7acf48cc81bd to your computer and use it in GitHub Desktop.
Save pyaf/9a28637614d4dbe6d2cc7acf48cc81bd to your computer and use it in GitHub Desktop.
Data pipeline and data augmentation on MURA dataset
# Dataset splits; get_dataloaders builds one transform pipeline and one DataLoader per entry.
data_cat = ["train", "valid"]
class ImageDataset(Dataset):
    """Study-level dataset: each item stacks every image belonging to one study.

    Each row of ``df`` describes one study, read positionally:
      column 0 -> study directory path. It is concatenated directly with the
                  file name, so it must already end with a path separator.
      column 1 -> number of images in the study.
      column 2 -> label returned alongside the images.
    The images are loaded as ``<study_path>image1.png`` ... ``image<count>.png``.
    """

    def __init__(self, df, transform=None, study_level=False):
        """
        Args:
            df (pd.DataFrame): a pandas DataFrame with the study path, image
                count and label in its first three columns.
            transform (callable): applied to each loaded image. It must return a
                tensor, because the per-study images are combined with
                ``torch.stack``. Despite the ``None`` default, a transform is
                required before indexing.
            study_level (bool): accepted for compatibility with callers that
                pass it; not used by this class.
        """
        self.df = df
        self.transform = transform
        self.study_level = study_level

    def __len__(self):
        """Return the number of studies (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'images': Tensor, 'label': ...}`` for the study in row ``idx``.

        ``images`` is the ``torch.stack`` of the transformed images. It
        therefore has a new leading dimension whose size is the study's image
        count.

        Raises:
            TypeError: if no ``transform`` was provided.
        """
        # Fail fast with a clear message instead of "'NoneType' object is not
        # callable", and before touching the disk.
        if self.transform is None:
            raise TypeError(
                "ImageDataset requires a transform that converts images to tensors"
            )
        study_path = self.df.iloc[idx, 0]
        count = self.df.iloc[idx, 1]
        label = self.df.iloc[idx, 2]
        # Image files in a study are numbered from 1: image1.png, image2.png, ...
        images = torch.stack([
            self.transform(pil_loader(study_path + f"image{number}.png"))
            for number in range(1, count + 1)
        ])
        return {'images': images, 'label': label}
def get_dataloaders(data, batch_size=8, study_level=False, num_workers=4):
    """Return one DataLoader per split in ``data_cat``, with augmentation on 'train'.

    Args:
        data (dict): maps each split name in ``data_cat`` ('train', 'valid') to
            the DataFrame that ``ImageDataset`` reads for that split.
        batch_size (int): batch size passed to each DataLoader.
        study_level (bool): accepted for API compatibility. It is not used by
            this pipeline.
        num_workers (int): worker processes per DataLoader (default 4, the
            previously hard-coded value).

    Returns:
        dict: split name -> ``DataLoader`` over an ``ImageDataset`` for that split.
    """
    # These values match the widely used ImageNet per-channel mean/std.
    # Presumably chosen for an ImageNet-pretrained model; confirm with the caller.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            normalize,
        ]),
        'valid': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    # study_level is intentionally not forwarded. ImageDataset does not use it,
    # and passing it to a constructor without that parameter raises TypeError.
    image_datasets = {
        x: ImageDataset(data[x], transform=data_transforms[x]) for x in data_cat
    }
    # NOTE(review): the validation split is shuffled too (unchanged behavior).
    # Consider shuffle=(x == 'train') for a deterministic evaluation order.
    dataloaders = {
        x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)
        for x in data_cat
    }
    return dataloaders
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment