Created
December 12, 2016 06:52
-
-
Save kyamagu/4f47cc7e06401141b34fa642c5170258 to your computer and use it in GitHub Desktop.
Script to convert CelebA dataset to LMDB format
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
''' | |
Example: | |
python data/celeba/scripts/build_dataset.py \ | |
--output_dir data/celeba/ | |
./build/tools/compute_image_mean \ | |
data/celeba/train-images.lmdb \ | |
data/celeba/mean.binaryproto | |
''' | |
import argparse | |
import itertools | |
import hashlib | |
import lmdb | |
import logging | |
import numpy as np | |
import os | |
import shutil | |
import sys | |
from cStringIO import StringIO | |
sys.path.append(os.path.join(os.getenv('CAFFE_ROOT', './'), 'python')) | |
from caffe.proto.caffe_pb2 import Datum | |
# Default names for input data. | |
IMAGES = 'data/celeba/Img/img_align_celeba/' | |
SPLITS = 'data/celeba/Eval/list_eval_partition.txt' | |
ATTRIBUTES = 'data/celeba/Anno/list_attr_celeba.txt' | |
def _clean_old_lmdb(args): | |
'''Clean up existing LMDBs. | |
''' | |
for split in ["train", "val", "test"]: | |
output_dir = os.path.join(args.output_dir, split + "-images.lmdb") | |
if os.path.exists(output_dir): | |
logging.info("Removing existing: {0}".format(output_dir)) | |
shutil.rmtree(output_dir) | |
output_dir = os.path.join(args.output_dir, split + "-labels.lmdb") | |
if os.path.exists(output_dir): | |
logging.info("Removing existing: {0}".format(output_dir)) | |
shutil.rmtree(output_dir) | |
def load_attributes(filename):
    '''Load the CelebA attribute annotation file.

    Expected format: line 1 = number of images, line 2 = space-separated
    attribute names, then one row per image: ``<filename> <+/-1> <+/-1> ...``.

    Args:
        filename: path to ``list_attr_celeba.txt``.

    Returns:
        dict with keys:
            'files': 1-D array of image file names.
            'attribute_names': list of attribute name strings.
            'attributes': (num_images, num_attributes) boolean array, True
                where the raw annotation is +1.
    '''
    with open(filename, 'r') as f:
        num_images = int(f.readline().strip())
        print("{} records".format(num_images))
        attributes = f.readline().strip().split(" ")
        print("{} attributes".format(len(attributes)))
        print("{}".format(attributes))
        # First pass (from the current position, i.e. past the two header
        # lines): file names only. `np.str` was removed from NumPy; plain
        # `str` is the correct dtype. `ndmin=1` keeps a single-row file 1-D.
        files = np.loadtxt(f, usecols=[0], dtype=str, ndmin=1)
        # Second pass: rewind and read the numeric columns, skipping the two
        # header lines. `> 0` maps the raw +1/-1 labels to booleans.
        # `range` replaces Py2-only `xrange`; `int` replaces removed `np.int`.
        f.seek(0)
        data = np.loadtxt(f, usecols=[i + 1 for i in range(len(attributes))],
                          dtype=int, skiprows=2, ndmin=2) > 0
        assert files.size == data.shape[0]
        print("Finished loading {}".format(filename))
        return {'files': files, 'attribute_names': attributes, 'attributes': data}
def load_splits(filename):
    '''Load the CelebA evaluation-partition file.

    Each row is ``<filename> <split>``; per the routing in _create_lmdb the
    split codes are 0=train (default), 1=val, 2=test.

    Args:
        filename: path to ``list_eval_partition.txt``.

    Returns:
        1-D integer array of split codes, one per image.
    '''
    with open(filename, 'r') as f:
        # `np.int` was removed from NumPy; plain `int` is equivalent.
        # `ndmin=1` keeps a single-row file 1-D so `.size` stays meaningful.
        splits = np.loadtxt(f, usecols=(1,), dtype=int, ndmin=1)
        print("Finished loading {}".format(filename))
        return splits
def read_input(args):
    '''Yield one (image_path, attributes, split) triple per CelebA image.

    Args:
        args: parsed namespace providing ``input_splits``,
            ``input_attributes`` and ``input_images`` paths.

    Yields:
        (full image path, boolean attribute vector, integer split code).
    '''
    # Load input data.
    splits = load_splits(args.input_splits)
    data = load_attributes(args.input_attributes)
    # Both annotation files must describe the same images in the same order.
    assert data['files'].size == splits.size
    image_dir = args.input_images
    # `range` replaces Py2-only `xrange` for Py2/Py3 compatibility.
    for i in range(data['files'].size):
        image_path = os.path.join(image_dir, data['files'][i])
        yield image_path, data['attributes'][i, :], splits[i]
def grouper(iterable, n):
    '''Yield successive tuples of up to *n* items from *iterable*.

    The final tuple may be shorter when the input length is not an exact
    multiple of *n*; an empty input yields nothing.
    '''
    source = iter(iterable)
    chunk = tuple(itertools.islice(source, n))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(source, n))
def _create_lmdb(args):
    '''Build six LMDBs: image and label databases for train/val/test.

    Records are keyed by the MD5 hex digest of the image path (which also
    shuffles insertion order). Images are stored as already-encoded bytes;
    labels as a packed boolean vector with one channel per attribute.

    Args:
        args: parsed namespace providing ``output_dir`` and ``batch_size``.
    '''
    count = 0
    with lmdb.open(os.path.join(args.output_dir, "train-images.lmdb"),
                   map_size=(1<<40), create=True) as env1, \
         lmdb.open(os.path.join(args.output_dir, "val-images.lmdb"),
                   map_size=(1<<40), create=True) as env2, \
         lmdb.open(os.path.join(args.output_dir, "test-images.lmdb"),
                   map_size=(1<<40), create=True) as env3, \
         lmdb.open(os.path.join(args.output_dir, "train-labels.lmdb"),
                   map_size=(1<<40), create=True) as env4, \
         lmdb.open(os.path.join(args.output_dir, "val-labels.lmdb"),
                   map_size=(1<<40), create=True) as env5, \
         lmdb.open(os.path.join(args.output_dir, "test-labels.lmdb"),
                   map_size=(1<<40), create=True) as env6:
        # One write transaction per environment per batch keeps commits
        # bounded instead of one giant transaction for the whole dataset.
        for batch in grouper(read_input(args), args.batch_size):
            with env1.begin(write=True) as txn1, \
                 env2.begin(write=True) as txn2, \
                 env3.begin(write=True) as txn3, \
                 env4.begin(write=True) as txn4, \
                 env5.begin(write=True) as txn5, \
                 env6.begin(write=True) as txn6:
                for image_path, labels, split in batch:
                    # hashlib.md5 requires bytes on Python 3, and LMDB keys
                    # must be bytes too — encode both ways. On Python 2 this
                    # is a no-op for ASCII paths.
                    key = hashlib.md5(
                        image_path.encode('utf-8')).hexdigest().encode('ascii')
                    # Image: 20% val, 20% test, 60% train.
                    datum = Datum()
                    # Close the image file deterministically (the original
                    # leaked the handle until garbage collection).
                    with open(image_path, 'rb') as image_file:
                        datum.data = image_file.read()
                    datum.encoded = True
                    datum.label = 0
                    # Split codes route to val (1) / test (2); anything else
                    # (i.e. 0) defaults to train.
                    txn = {1: txn2, 2: txn3}.get(split, txn1)
                    txn.put(key, datum.SerializeToString())
                    # Labels.
                    datum = Datum()
                    datum.data = labels.tobytes()
                    datum.channels = len(labels)
                    datum.width = 1
                    datum.height = 1
                    datum.label = 0
                    txn = {1: txn5, 2: txn6}.get(split, txn4)
                    txn.put(key, datum.SerializeToString())
                    count += 1
            logging.info("Processed {0}".format(count))
def _main(args):
    '''Entry point: wipe any stale LMDB outputs, then build fresh ones.'''
    for step in (_clean_old_lmdb, _create_lmdb):
        step(args)
if __name__ == '__main__':
    logging.basicConfig(format='[%(asctime)s] %(message)s',
                        level=logging.DEBUG)
    parser = argparse.ArgumentParser(description='Create LMDB files.')
    # Table-driven CLI definition: (flag, type, default, help) per option.
    cli_options = (
        ('--input_images', str, IMAGES, 'Input image directory.'),
        ('--input_attributes', str, ATTRIBUTES, 'Input annotation text file.'),
        ('--input_splits', str, SPLITS, 'Input split text file.'),
        ('--output_dir', str, "data/celeba", 'Output directory.'),
        ('--batch_size', int, 1000, 'Batch size.'),
    )
    for flag, flag_type, default, help_text in cli_options:
        parser.add_argument(flag, type=flag_type, default=default,
                            help=help_text)
    _main(parser.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment