Skip to content

Instantly share code, notes, and snippets.

@youngjung
Created April 18, 2019 07:17
Show Gist options
  • Save youngjung/73626bd95d9ce309cd3ba97b963867d3 to your computer and use it in GitHub Desktop.
Save youngjung/73626bd95d9ce309cd3ba97b963867d3 to your computer and use it in GitHub Desktop.
Create an ImageFolder-style directory from CIFAR-10, saving each image in a subdirectory named after its class.
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torchvision
import torchvision.transforms as transforms
# Class names used as output subdirectory names; main() indexes this list with
# the integer label of each sample, so the order must match the dataset's
# label ids.
label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# tensor (C, H, W, values in [-1, 1]) to an (H, W, 3) uint8 numpy array ready for PIL
def tensor2img(img):
    """Convert a normalised image tensor to an (H, W, 3) uint8 array.

    Args:
        img: tensor laid out as (C, H, W) with values nominally in [-1, 1].
            A single-channel tensor is replicated to three channels.

    Returns:
        numpy.ndarray of dtype uint8 with shape (H, W, C), where -1 maps
        to 0 and +1 maps to 255.
    """
    # detach() lets tensors that still require grad be converted as well.
    arr = img.detach().cpu().float().numpy()
    if arr.shape[0] == 1:
        # Grayscale input: repeat the single channel three times.
        arr = np.tile(arr, (3, 1, 1))
    # Channels-last layout, then undo the [-1, 1] normalisation.
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Round instead of truncating: float error can turn e.g. 254.99998 into
    # 254 under a plain uint8 cast. Clip so out-of-range values saturate
    # instead of wrapping around.
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)
def save_imgs(imgs, names, path):
    """Write each tensor in ``imgs`` to ``<path>/<name>.png``.

    Args:
        imgs: iterable of image tensors accepted by ``tensor2img``.
        names: iterable of file stems, paired with ``imgs`` by position.
        path: directory the stems are joined onto. No directories are
            created here.
    """
    for tensor, stem in zip(imgs, names):
        pixels = tensor2img(tensor)
        target = os.path.join(path, stem + '.png')
        Image.fromarray(pixels).save(target)
def main():
    """Export the CIFAR-10 training split as PNG files grouped by class.

    Command-line arguments:
        --dir_dataset: where the CIFAR-10 archive is downloaded/read (required).
        --dir_dest: root of the output tree; one subdirectory per class (required).
        --img_size: size passed to ``transforms.Resize`` (default 32).
        --batch_size: DataLoader batch size (default 32).
        --num_workers: DataLoader worker count (default 2).

    Files are written as ``<dir_dest>/<class name>/<index>.png`` where
    ``index`` is an 8-digit running counter over the whole split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_dataset', type=str, required=True)
    parser.add_argument('--dir_dest', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=32)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=2)
    opts = parser.parse_args()

    # data loader
    print('\n--- load dataset ---')
    os.makedirs(opts.dir_dataset, exist_ok=True)
    dataset = torchvision.datasets.CIFAR10(
        opts.dir_dataset, train=True, download=True,
        transform=transforms.Compose([
            transforms.Resize(opts.img_size),
            transforms.ToTensor(),
            # Normalised with mean 0.5 / std 0.5; tensor2img undoes this.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    # shuffle=False keeps the running index tied to the dataset order, so the
    # export is reproducible and a re-run overwrites the same files instead of
    # scattering duplicates under other class directories.
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=opts.batch_size, shuffle=False,
        num_workers=opts.num_workers)

    # prepare dirs
    for name in label_names:
        os.makedirs(os.path.join(opts.dir_dest, name), exist_ok=True)

    # run
    print('\n--- run ---')
    start = 0  # index of the first image in the current batch
    for images, labels in tqdm(train_loader, total=len(train_loader)):
        # tolist() yields plain ints for indexing label_names.
        names = ['{}/{:08d}'.format(label_names[label], index)
                 for index, label in enumerate(labels.tolist(), start)]
        save_imgs(images, names, opts.dir_dest)
        start += images.size(0)
# Run the export only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment