Result

(venv) root@host:path/FFHQ# python convert_lmdb1024_list256.py
70000it [14:16, 81.68it/s]______________________| 69997/70000 [14:16<00:00, 26.75it/s]
100%|__________________________________________| 70000/70000 [14:16<00:00, 81.68it/s]

convert_ffhq_lmdb1024_dir256_multiprocessing_tqdm_htop

import os
import pickle
import string

import cv2
import lmdb
import numpy as np
from multiprocessing import Pool
from tqdm import tqdm


class LmdbLoader(object):
    """Defines a class to load an lmdb file.

    This is a static class, which is used to solve the lmdb loading error
    when num_workers > 0.
    """
    files = dict()

    @staticmethod
    def get_lmdbfile(file_path):
        """Fetches an lmdb file."""
        lmdb_files = LmdbLoader.files
        if 'env' not in lmdb_files:
            env = lmdb.open(file_path,
                            max_readers=1,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            with env.begin(write=False) as txn:
                num_samples = txn.stat()['entries']
            cache_file = '_cache_' + ''.join(
                c for c in file_path if c in string.ascii_letters)
            if os.path.isfile(cache_file):
                keys = pickle.load(open(cache_file, "rb"))
            else:
                with env.begin(write=False) as txn:
                    keys = [key for key, _ in txn.cursor()]
                pickle.dump(keys, open(cache_file, "wb"))
            lmdb_files['env'] = env
            lmdb_files['num_samples'] = num_samples
            lmdb_files['keys'] = keys
        return lmdb_files

    @staticmethod
    def get_image(file_path, idx):
        """Decodes an image from a particular lmdb file."""
        lmdb_files = LmdbLoader.get_lmdbfile(file_path)
        env = lmdb_files['env']
        keys = lmdb_files['keys']
        with env.begin(write=False) as txn:
            imagebuf = txn.get(keys[idx])
        image_np = np.frombuffer(imagebuf, np.uint8)
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        return image


def resize_and_save_one(args):
    """Reads one image from the lmdb, resizes it and saves it as a PNG."""
    image = LmdbLoader.get_image(args['pre_root_dir'], args['index'])
    # Take the target size from args so that worker processes do not depend
    # on a global `new_size` that only exists in the __main__ block.
    image = cv2.resize(image, dsize=args['new_size'],
                       interpolation=cv2.INTER_AREA)
    cv2.imwrite(args['new_image_path'], image)


def resize_and_save(pre_root_dir, new_root_dir, new_size=(256, 256),
                    train_ratio=0.9, thread_num=8):
    # Initialize parameters
    lmdb_file = LmdbLoader.get_lmdbfile(pre_root_dir)
    num_samples = lmdb_file['num_samples']
    train_cnt = int(num_samples * train_ratio)
    new_train_dir = f'{new_root_dir}/train'
    new_valid_dir = f'{new_root_dir}/val'
    # Create the output directories up front; cv2.imwrite fails silently
    # (returns False) if the target directory does not exist.
    os.makedirs(new_train_dir, exist_ok=True)
    os.makedirs(new_valid_dir, exist_ok=True)
    # Generate parameters for each worker
    parameters = []
    for i in range(num_samples):
        image_name = lmdb_file['keys'][i].decode()[:8]
        if i < train_cnt:
            new_image_path = f'{new_train_dir}/{image_name}.png'
        else:
            new_image_path = f'{new_valid_dir}/{image_name}.png'
        parameter = {
            'pre_root_dir': pre_root_dir,
            'index': i,
            'new_size': new_size,
            'new_image_path': new_image_path
        }
        parameters.append(parameter)
    # Iterate using a process pool; the nested tqdm calls produce the two
    # progress bars visible in the result above.
    pool = Pool(thread_num)
    with tqdm(total=num_samples) as pbar:
        for _ in tqdm(pool.imap_unordered(resize_and_save_one, parameters)):
            pbar.update()
    pool.close()
    pool.join()


if __name__ == '__main__':
    pre_root_dir = 'ffhq_1024.lmdb'
    new_root_dir = 'ffhq_256.dir'
    new_size = (256, 256)
    train_ratio = 0.99
    thread_num = 8
    resize_and_save(pre_root_dir, new_root_dir, new_size, train_ratio, thread_num)
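
For a quick sanity check after the conversion finishes, a snippet along these lines can be run. This is a minimal sketch of my own, not part of the gist; it assumes the ffhq_256.dir output path and the 0.99 split used above.

import glob

import cv2

# Count the PNGs written by the script; with 70000 source images and a
# 0.99 train ratio this should give 69300 train / 700 val files.
train_files = sorted(glob.glob('ffhq_256.dir/train/*.png'))
val_files = sorted(glob.glob('ffhq_256.dir/val/*.png'))
print(f'train: {len(train_files)}, val: {len(val_files)}')

# Open one converted image and confirm it really is 256x256.
sample = cv2.imread(train_files[0])
assert sample.shape[:2] == (256, 256), sample.shape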