Skip to content

Instantly share code, notes, and snippets.

@mk-devc
Last active March 5, 2021 08:08
Show Gist options
  • Save mk-devc/f1b2abcdea3bfd4decf8b6f13a01c2b3 to your computer and use it in GitHub Desktop.
Save mk-devc/f1b2abcdea3bfd4decf8b6f13a01c2b3 to your computer and use it in GitHub Desktop.
A custom PyTorch dataset for loading the CelebA dataset for training on Google Colab. Because of Google Drive limitations, it can serve as a manual alternative.
# Setup on Google Colab:
# 1. Copy the CelebA images and annotation files into your own Google Drive.
# 2. In each shareable link, change the `file/d/<ID>` part into `uc?id=<ID>`.
# 3. Store each rewritten link in its respective URL variable.
# 4. Run the `gdown` command on each URL to download the file straight into Colab.
import os

import pandas as pd
from skimage import io
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# Load the four CelebA annotation tables. Fields are separated by runs of
# whitespace; `sep=r"\s+"` is the supported spelling of the deprecated
# `delim_whitespace=True` and parses the files the same way.

# No header row: columns are addressed by position (0 and 1).
identity = pd.read_csv(
    '/content/identity_CelebA.txt', sep=r'\s+', header=None
)

# header=1: the second line of the file supplies the column names and the
# line above it is discarded.
# NOTE(review): presumably the discarded first line is a row count - confirm
# against the downloaded files.
bbox = pd.read_csv(
    '/content/list_bbox_celeba.txt', sep=r'\s+', header=1
)
# NOTE(review): the dataset class reads file names from `.index` for the
# landmark and attribute tables, which assumes pandas promoted the first
# field of each data row to the index - confirm after loading.
landmarks_align = pd.read_csv(
    '/content/list_landmarks_align_celeba.txt', sep=r'\s+', header=1
)
attr = pd.read_csv(
    '/content/list_attr_celeba.txt', sep=r'\s+', header=1
)
# Define the custom class
class celebADataset(Dataset):
    """Dataset pairing CelebA images with one kind of annotation.

    ``labels`` is a pandas DataFrame whose expected layout depends on
    ``target_type``:

    * ``"identity"``: column 0 holds the image file name and column 1 the
      target value.
    * ``"attr"`` / ``"landmarks"``: the frame's index holds the image file
      name and the whole row is the target.
    * ``"bbox"``: an ``image_id`` column holds the image file name and the
      remaining columns form the target.

    Args:
        root_dir: Directory that contains the image files.
        labels: Annotation DataFrame laid out as described above.
        target_type: One of ``"identity"``, ``"attr"``, ``"bbox"`` or
            ``"landmarks"``.
        transforms: Optional callable applied to the loaded image.
    """

    def __init__(self, root_dir, labels=None, target_type="identity",
                 transforms=None):
        """Store the image directory, annotations and options."""
        self.annotations = labels
        self.root_dir = root_dir
        self.transforms = transforms
        self.target_type = target_type

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)

    def _lookup(self, index):
        """Return ``(file_name, target)`` for row ``index``; reads no image.

        Raises:
            ValueError: If ``self.target_type`` is not a supported value.
        """
        table = self.annotations
        kind = self.target_type

        if kind == "identity":
            return str(table.iloc[index, 0]), table.iloc[index, 1]

        if kind in ("attr", "landmarks"):
            # The file name lives in the index, so the row is pure target data.
            return str(table.index[index]), table.iloc[index, :].to_numpy()

        if kind == "bbox":
            # Positional row access keeps this working when the frame does
            # not carry a default 0..n-1 index (e.g. after filtering).
            row = table.iloc[index]
            # Drop the file name so only the box values remain, then let
            # pandas restore a numeric dtype (a mixed row comes back as
            # object dtype).
            values = row.drop(labels="image_id").infer_objects().to_numpy()
            return str(row["image_id"]), values

        raise ValueError(
            "Target type \"{}\" is not recognized.".format(kind)
        )

    def __getitem__(self, index):
        """Return ``(image, target)`` for row ``index``.

        The image is read with ``skimage.io.imread`` from ``root_dir`` and
        passed through ``transforms`` when one was supplied.

        Raises:
            ValueError: If ``self.target_type`` is not a supported value.
        """
        # Resolve the row first so an unsupported target type fails before
        # any disk access.
        file_name, target = self._lookup(index)
        image = io.imread(os.path.join(self.root_dir, file_name))
        if self.transforms is not None:
            image = self.transforms(image)
        return (image, target)
# Example: build the identity-labelled dataset.
# NOTE(review): `transforms` is not imported anywhere in this file; it is
# presumably `torchvision.transforms` - add that import before running.
image_size = 64
root = '/content/img_align_celeba'

# Preprocessing applied to every loaded image: convert to a PIL image, resize
# and centre-crop to `image_size`, convert to a tensor, then normalise each of
# the three channels with 0.5 / 0.5.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# `target_type` is left at its default, "identity".
dataset = celebADataset(root, labels=identity, transforms=transform)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment