Last active
March 5, 2021 08:08
-
-
Save mk-devc/f1b2abcdea3bfd4decf8b6f13a01c2b3 to your computer and use it in GitHub Desktop.
This is a custom Dataset class for loading the CelebA dataset so it can be trained on in Google Colab. Because of Google Drive limitations, it can serve as a manual alternative.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Download the data to your Google Drive by making a copy of the annotation files.
# In each shareable link, change the "file/d/" part into "uc?id=".
# Store those links in their respective URL variables.
# Use the gdown command on each URL to download straight into Colab.
import os

import pandas as pd
from skimage import io
from torch.utils.data import DataLoader, Dataset
# Load the CelebA annotation tables. All four files are whitespace-separated
# text; sep=r'\s+' is the supported spelling of the deprecated
# delim_whitespace=True option and parses the files the same way.

# No header row: columns are addressed positionally (0 and 1) by the dataset.
identity = pd.read_csv('/content/identity_CelebA.txt', sep=r'\s+', header=None)

# header=1 takes the column names from the second line of each file and skips
# the line before it (presumably a record count -- confirm against the files).
bbox = pd.read_csv('/content/list_bbox_celeba.txt', sep=r'\s+', header=1)
landmarks_align = pd.read_csv('/content/list_landmarks_align_celeba.txt', sep=r'\s+', header=1)
attr = pd.read_csv('/content/list_attr_celeba.txt', sep=r'\s+', header=1)
# Define the custom class | |
class celebADataset(Dataset):
    """Dataset that pairs each image file with one kind of annotation.

    Args:
        root_dir: Directory that contains the image files.
        labels: Annotation table (a pandas DataFrame). The layout that is
            read depends on ``target_type``:

            * ``"identity"``: column 0 is the file name, column 1 the target.
            * ``"attr"`` / ``"landmarks"``: the row index is the file name
              and the whole row is the target.
            * ``"bbox"``: the ``image_id`` column is the file name and the
              remaining columns are the target.
        target_type: One of ``"identity"``, ``"attr"``, ``"bbox"`` or
            ``"landmarks"``.
        transforms: Optional callable applied to the loaded image.
    """

    def __init__(self, root_dir, labels=None, target_type="identity", transforms=None):
        """Store the arguments; nothing is read from disk here."""
        super().__init__()
        self.annotations = labels
        self.root_dir = root_dir
        self.transforms = transforms
        self.target_type = target_type

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def _lookup(self, index):
        """Return ``(file_name, target)`` for row ``index`` without any disk I/O.

        Raises:
            ValueError: If ``self.target_type`` is not a supported value.
        """
        table = self.annotations
        kind = self.target_type

        if kind == "identity":
            return str(table.iloc[index, 0]), table.iloc[index, 1]

        if kind in ("attr", "landmarks"):
            # These tables carry the file name in the row index.
            return str(table.index[index]), table.iloc[index, :].to_numpy()

        if kind == "bbox":
            # Positional lookup, so it keeps working when the table has been
            # filtered or shuffled and its index labels are no longer 0..n-1.
            row = table.iloc[index]
            # A mixed-type row comes back as object dtype; drop the file name
            # and let pandas infer a numeric dtype for the remaining values.
            box = row.drop("image_id").infer_objects().to_numpy()
            return str(row["image_id"]), box

        raise ValueError("Target type \"{}\" is not recognized.".format(kind))

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, target)``. ``image`` is whatever ``io.imread``
            returns, passed through ``self.transforms`` when one was given.

        Raises:
            ValueError: If ``self.target_type`` is not a supported value.
        """
        file_name, target = self._lookup(index)
        image = io.imread(os.path.join(self.root_dir, file_name))
        if self.transforms:
            image = self.transforms(image)
        return (image, target)
# Example: build the dataset with the identity labels as targets.
# NOTE(review): `transforms` is not imported in the visible part of this file;
# presumably `from torchvision import transforms` -- confirm.
image_size = 64
root = '/content/img_align_celeba'

# Preprocessing applied to every loaded image, in this order.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# target_type is left at its default ("identity"), matching the labels table.
dataset = celebADataset(root, labels=identity, transforms=transform)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment