Skip to content

Instantly share code, notes, and snippets.

@kyamagu
Created December 12, 2016 06:52
Show Gist options
  • Save kyamagu/4f47cc7e06401141b34fa642c5170258 to your computer and use it in GitHub Desktop.
Save kyamagu/4f47cc7e06401141b34fa642c5170258 to your computer and use it in GitHub Desktop.
Script to convert CelebA dataset to LMDB format
#!/usr/bin/env python
'''
Example:
python data/celeba/scripts/build_dataset.py \
--output_dir data/celeba/
./build/tools/compute_image_mean \
data/celeba/train-images.lmdb \
data/celeba/mean.binaryproto
'''
import argparse
import itertools
import hashlib
import lmdb
import logging
import numpy as np
import os
import shutil
import sys
from cStringIO import StringIO
sys.path.append(os.path.join(os.getenv('CAFFE_ROOT', './'), 'python'))
from caffe.proto.caffe_pb2 import Datum
# Default names for input data.
# Paths follow the standard CelebA release layout, relative to the
# Caffe project root (see the usage example in the module docstring).
IMAGES = 'data/celeba/Img/img_align_celeba/'  # directory of image files (aligned set, per path name)
SPLITS = 'data/celeba/Eval/list_eval_partition.txt'  # per-image split ids
ATTRIBUTES = 'data/celeba/Anno/list_attr_celeba.txt'  # per-image binary attribute annotations
def _clean_old_lmdb(args):
'''Clean up existing LMDBs.
'''
for split in ["train", "val", "test"]:
output_dir = os.path.join(args.output_dir, split + "-images.lmdb")
if os.path.exists(output_dir):
logging.info("Removing existing: {0}".format(output_dir))
shutil.rmtree(output_dir)
output_dir = os.path.join(args.output_dir, split + "-labels.lmdb")
if os.path.exists(output_dir):
logging.info("Removing existing: {0}".format(output_dir))
shutil.rmtree(output_dir)
def load_attributes(filename):
    '''Load the CelebA attribute annotation file.

    Expected file layout (as read by this function): line 1 is the image
    count, line 2 the space-separated attribute names, and each remaining
    line is an image filename followed by one +1/-1 value per attribute.

    Args:
        filename: path to the annotation text file.

    Returns:
        dict with keys 'files' (1-D array of filenames),
        'attribute_names' (list of str) and 'attributes'
        (bool array, shape (num_images, num_attributes); True where +1).
    '''
    with open(filename, 'r') as f:
        num_images = int(f.readline().strip())
        print("{} records".format(num_images))
        attributes = f.readline().strip().split(" ")
        print("{} attributes".format(len(attributes)))
        print("{}".format(attributes))
        # First pass over the remaining lines: filenames only.
        # (np.str/np.int were removed in numpy >= 1.24; the builtins are
        # what those aliases always meant.  xrange -> range for Python 3.)
        files = np.loadtxt(f, usecols=[0], dtype=str)
        # Rewind and re-read the numeric columns; "> 0" maps +1/-1 to bool.
        f.seek(0)
        data = np.loadtxt(f, usecols=[i + 1 for i in range(len(attributes))],
                          dtype=int, skiprows=2) > 0
        assert files.size == data.shape[0]
        print("Finished loading {}".format(filename))
        return {'files': files, 'attribute_names': attributes, 'attributes': data}
def load_splits(filename):
    '''Load the per-image split id (second column) from the partition file.

    Args:
        filename: path to the eval-partition text file
            ("<image name> <split id>" per line).

    Returns:
        1-D int array of split ids, one per image.
        (Per the usage in _create_lmdb: 1 -> val, 2 -> test, else train.)
    '''
    with open(filename, 'r') as f:
        # np.int was removed in numpy >= 1.24; the builtin is equivalent.
        splits = np.loadtxt(f, usecols=(1,), dtype=int)
    print("Finished loading {}".format(filename))
    return splits
def read_input(args):
    '''Yield one (image_path, attribute_row, split_id) triple per image.

    Joins the attribute annotations with the eval-partition file; the two
    inputs must describe the same images in the same order (asserted by
    comparing their sizes).

    Args:
        args: parsed arguments; uses input_splits, input_attributes
            and input_images.

    Yields:
        (str path, bool attribute vector, int split id) per image.
    '''
    # Load input data.
    splits = load_splits(args.input_splits)
    data = load_attributes(args.input_attributes)
    assert data['files'].size == splits.size
    image_dir = args.input_images
    # xrange is Python-2-only; range is behavior-identical here.
    for i in range(data['files'].size):
        image_path = os.path.join(image_dir, data['files'][i])
        yield image_path, data['attributes'][i, :], splits[i]
def grouper(iterable, n):
    '''Yield successive tuples of at most n items from iterable.

    The final tuple is shorter when the input length is not a multiple
    of n; an empty input yields nothing.
    '''
    source = iter(iterable)
    chunk = tuple(itertools.islice(source, n))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(source, n))
def _create_lmdb(args):
    '''Write the image and label LMDBs for all three splits.

    Opens six LMDB environments under args.output_dir
    ("<split>-images.lmdb" and "<split>-labels.lmdb" for
    train/val/test), then streams records from read_input() in batches
    of args.batch_size.  The image entry and label entry for the same
    file share an MD5-of-path key so they can be joined downstream.
    '''
    # Create Image LMDB files.
    count = 0
    # map_size reserves up to 1 TiB of address space; LMDB only consumes
    # the disk actually written.
    with lmdb.open(os.path.join(args.output_dir, "train-images.lmdb"),
                   map_size=(1<<40), create=True) as env1, \
         lmdb.open(os.path.join(args.output_dir, "val-images.lmdb"),
                   map_size=(1<<40), create=True) as env2, \
         lmdb.open(os.path.join(args.output_dir, "test-images.lmdb"),
                   map_size=(1<<40), create=True) as env3, \
         lmdb.open(os.path.join(args.output_dir, "train-labels.lmdb"),
                   map_size=(1<<40), create=True) as env4, \
         lmdb.open(os.path.join(args.output_dir, "val-labels.lmdb"),
                   map_size=(1<<40), create=True) as env5, \
         lmdb.open(os.path.join(args.output_dir, "test-labels.lmdb"),
                   map_size=(1<<40), create=True) as env6:
        for batch in grouper(read_input(args), args.batch_size):
            # One write transaction per environment per batch keeps the
            # amount of uncommitted data bounded.
            with env1.begin(write=True) as txn1, \
                 env2.begin(write=True) as txn2, \
                 env3.begin(write=True) as txn3, \
                 env4.begin(write=True) as txn4, \
                 env5.begin(write=True) as txn5, \
                 env6.begin(write=True) as txn6:
                for image_path, labels, split in batch:
                    # Key both databases by the MD5 of the image path so
                    # image and label entries match up.
                    # NOTE(review): hashlib.md5() requires bytes on
                    # Python 3; image_path is str, so this line is
                    # Python-2-only as written.
                    key = hashlib.md5(image_path).hexdigest()
                    # Image: 20% val, 20% test, 60% train.
                    datum = Datum()
                    # Raw file bytes are stored as-is and flagged as
                    # already encoded (JPEG decoding is left to Caffe).
                    datum.data = open(image_path, 'rb').read()
                    datum.encoded = True
                    datum.label = 0
                    # Split id 1 -> val, 2 -> test, anything else -> train.
                    txn = {1: txn2, 2: txn3}.get(split, txn1)
                    txn.put(key, datum.SerializeToString())
                    # Labels.
                    datum = Datum()
                    # Boolean attribute vector packed as a C x 1 x 1 blob.
                    datum.data = labels.tobytes()
                    datum.channels = len(labels)
                    datum.width = 1
                    datum.height = 1
                    datum.label = 0
                    txn = {1: txn5, 2: txn6}.get(split, txn4)
                    txn.put(key, datum.SerializeToString())
                    count += 1
            # Progress log once per committed batch.
            logging.info("Processed {0}".format(count))
def _main(args):
    '''Rebuild the CelebA LMDBs: remove stale output, then convert.'''
    _clean_old_lmdb(args)
    _create_lmdb(args)
if __name__ == '__main__':
    logging.basicConfig(format='[%(asctime)s] %(message)s',
                        level=logging.DEBUG)
    # Command-line interface; each option falls back to the standard
    # CelebA layout defined at the top of the file.
    parser = argparse.ArgumentParser(description='Create LMDB files.')
    option_table = (
        ('--input_images', str, IMAGES, 'Input image directory.'),
        ('--input_attributes', str, ATTRIBUTES, 'Input annotation text file.'),
        ('--input_splits', str, SPLITS, 'Input split text file.'),
        ('--output_dir', str, "data/celeba", 'Output directory.'),
        ('--batch_size', int, 1000, 'Batch size.'),
    )
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)
    _main(parser.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment