Created
April 5, 2021 22:15
-
-
Save previtus/cb86dba0c0d9746dbfe9f7374514fbdd to your computer and use it in GitHub Desktop.
celeb a dataset at 32x32 resolution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
from torchvision import datasets, transforms | |
import torch | |
import torch.utils.data as data | |
import urllib.request | |
import scipy.io | |
import os | |
import imageio | |
import numpy as np | |
from os import listdir | |
import os.path | |
import time | |
# NOTE: the data still needs some preprocessing.
# Download and unpack the dataset from: https://drive.google.com/file/d/1eKq1RzppY6FYHqG1j1CWrWEvWNksZJSB/view?usp=sharing
def load_celeb_a_32x32(cuda, batch_size, path="datasets/CelebA32x32/", train_test_split=0.9):
    """Build shuffled train/test DataLoaders over the CelebA 32x32 images.

    On the first call the images are read from ``<path>/data32x32/*.jpg`` and
    cached as ``<path>/celeba_32x32.npy``; later calls load that cache
    directly. The dataset archive must be downloaded and unpacked beforehand
    (https://drive.google.com/file/d/1eKq1RzppY6FYHqG1j1CWrWEvWNksZJSB/view?usp=sharing).

    The image array is handed to the DataLoaders as stored; no scaling or
    normalisation is applied here.

    Args:
        cuda: when truthy, the loaders use one worker process and pinned memory.
        batch_size: batch size used by both loaders.
        path: dataset root directory, with or without a trailing separator.
        train_test_split: fraction of the samples placed in the training set.

    Returns:
        A ``(train_loader, test_loader, input_size)`` tuple, where
        ``input_size`` is ``32 * 32 * 3``.

    Raises:
        FileNotFoundError: if there is no cache file and the image directory
            is missing or contains no ``.jpg`` files.
    """
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    # os.path.join keeps the file names correct even when `path` is passed
    # without a trailing slash.
    baked_path = os.path.join(path, "celeba_32x32.npy")
    if os.path.isfile(baked_path):
        images = np.load(baked_path)
    else:
        image_dir = os.path.join(path, "data32x32")
        image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith(".jpg")
        )
        if not image_paths:
            # Fail loudly instead of writing an empty cache that every later
            # call would then load.
            raise FileNotFoundError("no .jpg images found in " + image_dir)
        # imageio returns each image in RGB channel order.
        images = np.asarray([imageio.imread(p) for p in image_paths])
        np.save(baked_path, images)
    print("data loaded as:", images.shape)

    # Fixed-seed shuffle so the train/test membership is the same on every
    # run. A private RandomState yields the same permutation as seeding the
    # global generator with 42, without clobbering the caller's RNG state.
    np.random.RandomState(42).shuffle(images)

    # train-test split
    split_idx = int(len(images) * train_test_split)
    train = images[:split_idx]
    test = images[split_idx:]
    print("split with train=", train.shape, "test=", test.shape)

    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=True, **loader_kwargs
    )

    input_size = int(32 * 32 * 3)
    return train_loader, test_loader, input_size
""" | |
cuda = False
batch_size = 20
train_loader, test_loader, input_size = load_celeb_a_32x32(cuda, batch_size=batch_size)
print("train_loader", len(train_loader) * batch_size, "test_loader", len(test_loader) * batch_size, "input_size=", input_size)
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment