Skip to content

Instantly share code, notes, and snippets.

Created September 19, 2018 16:34
Show Gist options
  • Save christopher-beckham/7fa3b258bc9ba361b921af407a051303 to your computer and use it in GitHub Desktop.
Save christopher-beckham/7fa3b258bc9ba361b921af407a051303 to your computer and use it in GitHub Desktop.
import glob
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
class CelebADataset(Dataset):
    """CelebA images paired with binary attribute labels.

    Images are the ``*.jpg`` files directly inside ``root`` (sorted by path).
    Attribute annotations are read from ``root/list_attr_celeba.txt``.
    The last 2000 images form the non-training split. With 2000 or fewer
    images, the training split is empty.

    Args:
        root: Directory holding the ``*.jpg`` images and
            ``list_attr_celeba.txt``.
        transforms_: List of torchvision transforms, composed in order and
            applied to each image. Defaults to ``[transforms.ToTensor()]``.
        mode: ``'train'`` selects all but the last 2000 images. Any other
            value selects the last 2000.
        attributes: Attribute names (columns of the annotation-file header)
            to return as labels. If empty, ``__getitem__`` returns only the
            image.
        missing_ind: If True, append one extra label that is 1 when none of
            the selected attributes is present (an "everything else" class).
    """

    def __init__(self, root, transforms_=None, mode='train',
                 attributes=None, missing_ind=False):
        if transforms_ is None:
            # Compose(None) would fail on first call; default to a tensor
            # conversion so samples can be batched by a DataLoader.
            transforms_ = [transforms.ToTensor()]
        self.transform = transforms.Compose(transforms_)
        # Copy so neither the caller's list nor a shared default is aliased.
        self.selected_attrs = list(attributes) if attributes is not None else []
        files = sorted(glob.glob(os.path.join(root, '*.jpg')))
        self.files = files[:-2000] if mode == 'train' else files[-2000:]
        self.label_path = os.path.join(root, 'list_attr_celeba.txt')
        self.missing_ind = missing_ind
        self.annotations = self.get_annotations()
        self.keys = list(self.annotations.keys())

    def get_annotations(self):
        """Parse ``list_attr_celeba.txt`` into per-file label lists.

        The file's second line holds the attribute names. Each following
        non-blank line is ``<filename> <v_1> ... <v_k>``, where a value of
        ``'1'`` marks the attribute as present. Also sets
        ``self.label_names`` to the list of all attribute names.

        Returns:
            dict mapping image filename to a list of 0/1 ints, one per
            entry of ``self.selected_attrs``, plus the trailing
            "everything else" indicator when ``self.missing_ind`` is set.

        Raises:
            ValueError: if a selected attribute is not in the file header.
        """
        with open(self.label_path, 'r') as f:
            lines = [line.rstrip() for line in f]
        self.label_names = lines[1].split()

        # Resolve column indices once rather than for every annotation line.
        indices = []
        for attr in self.selected_attrs:
            try:
                indices.append(self.label_names.index(attr))
            except ValueError:
                raise ValueError(
                    'unknown CelebA attribute %r; expected one of %s'
                    % (attr, self.label_names)) from None

        annotations = {}
        for line in lines[2:]:
            if not line:
                continue  # tolerate blank lines, e.g. at end of file
            filename, *values = line.split()
            labels = [int(values[idx] == '1') for idx in indices]
            if self.missing_ind:
                # Extra "everything else" label: 1 only when none of the
                # selected attributes is present.
                labels.append(0 if 1 in labels else 1)
            annotations[filename] = labels
        return annotations

    def __getitem__(self, index):
        """Return the transformed image, plus its float label tensor when
        attributes were selected. ``index`` wraps modulo the file count."""
        filepath = self.files[index % len(self.files)]
        with Image.open(filepath) as image:
            # convert() loads the pixel data before the file is closed and
            # guarantees a 3-channel image.
            img = self.transform(image.convert('RGB'))
        if not self.selected_attrs:
            return img
        filename = os.path.basename(filepath)
        label = torch.tensor(self.annotations[filename], dtype=torch.float32)
        return img, label

    def __len__(self):
        """Number of image files in the selected split."""
        return len(self.files)
Copy link

root -> the folder containing the extracted CelebA images (the `*.jpg` files).
list_attr_celeba.txt -> must be placed in that same `root` folder.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment