Skip to content

Instantly share code, notes, and snippets.

@McSpooder
Created July 22, 2020 18:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save McSpooder/f3af9a932ff43e9a422235a9d51eb28b to your computer and use it in GitHub Desktop.
Save McSpooder/f3af9a932ff43e9a422235a9d51eb28b to your computer and use it in GitHub Desktop.
PyTorch data loader
class MRIDataset(Dataset):
    """Index MRI sample paths grouped by label sub-directory.

    Every entry found directly inside ``root_dir/<label>`` is collected as a
    ``pathlib.Path`` into ``self.directories``. Labels are scanned in the
    order given by ``labels``, and entries are sorted by path within each
    label so that dataset indices are reproducible across runs and machines.

    Args:
        root_dir: Directory containing one sub-directory per label. May be a
            ``str`` or ``pathlib.Path``, with or without a trailing separator.
        labels: Names of the label sub-directories to scan.
        transform: Optional transform, stored as ``self.transform``. It is
            not applied anywhere in this class.
        clin_csv: Path of the clinical-data CSV that is loaded into
            ``self.clin_data`` with ``pandas.read_csv``. Defaults to the
            previously hard-coded location.

    NOTE(review): no ``__getitem__`` is defined here. Indexing the dataset
    (e.g. from a DataLoader) presumably still needs one to be implemented.
    """

    def __init__(self, root_dir, labels, transform=None,
                 clin_csv="../data4/lon_clin.csv"):
        self.root_dir = root_dir
        self.transform = transform
        self.labels = labels
        self.clin_data = pd.read_csv(clin_csv)

        # Join with pathlib instead of ``+`` so a missing trailing separator
        # (or a Path-typed root_dir) still yields root_dir/label.
        base = pathlib.Path(root_dir)
        self.directories = []
        for label in labels:
            # Escape the directory part so glob metacharacters in real path
            # names ("[", "*", "?") are matched literally. glob.glob skips
            # dot-files, just as the original scan did.
            pattern = glob.escape(str(base / label)) + "/*"
            # glob order is filesystem-dependent; sort for stable indices.
            self.directories.extend(
                pathlib.Path(p) for p in sorted(glob.glob(pattern))
            )

        # Kept as an attribute for any callers that read ``self.len``.
        self.len = len(self.directories)

    def __len__(self):
        """Return the number of collected sample paths."""
        return self.len
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment