(venv) root@host:path/FFHQ# python convert_lmdb1024_list256.py
70000it [14:16, 81.68it/s]________________________________________________________________________________________________________________________________________________________________________________________________________________________________| 69997/70000 [14:16<00:00, 26.75it/s]
100%|_____________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________| 70000/70000 [14:16<00:00, 81.68it/s]
Last active
September 28, 2021 02:40
-
-
Save kingsj0405/79d6f6c592ee23f55cb7708d43cea197 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import string | |
import lmdb | |
import cv2 | |
import numpy as np | |
from multiprocessing import Pool | |
from tqdm import tqdm | |
class LmdbLoader(object):
    """Static helper that opens one lmdb environment and caches it.

    The environment, sample count, and key list live in a class-level dict
    so the (expensive) open/scan happens once per process; per the original
    author's note this works around lmdb loading errors when num_workers > 0.
    """

    # Shared cache: {'env': lmdb Environment, 'num_samples': int, 'keys': list of bytes}
    files = dict()

    @staticmethod
    def get_lmdbfile(file_path):
        """Open (once) and return the cached lmdb info dict for *file_path*.

        Returns:
            dict with keys 'env', 'num_samples', and 'keys'.
        """
        lmdb_files = LmdbLoader.files
        if 'env' not in lmdb_files:
            env = lmdb.open(file_path,
                            max_readers=1,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            with env.begin(write=False) as txn:
                num_samples = txn.stat()['entries']
            # The full key scan is slow, so it is cached on disk. The cache
            # name keeps only the ASCII letters of the path, so two distinct
            # paths can collide; acceptable for this one-off conversion script.
            cache_file = '_cache_' + ''.join(
                c for c in file_path if c in string.ascii_letters)
            if os.path.isfile(cache_file):
                # Bug fix: the original used pickle.load(open(...)) and
                # pickle.dump(..., open(...)), leaking the file handles.
                with open(cache_file, "rb") as f:
                    keys = pickle.load(f)
            else:
                with env.begin(write=False) as txn:
                    keys = [key for key, _ in txn.cursor()]
                with open(cache_file, "wb") as f:
                    pickle.dump(keys, f)
            lmdb_files['env'] = env
            lmdb_files['num_samples'] = num_samples
            lmdb_files['keys'] = keys
        return lmdb_files

    @staticmethod
    def get_image(file_path, idx):
        """Fetch and decode (BGR, via cv2.IMREAD_COLOR) the idx-th image."""
        lmdb_files = LmdbLoader.get_lmdbfile(file_path)
        env = lmdb_files['env']
        keys = lmdb_files['keys']
        with env.begin(write=False) as txn:
            imagebuf = txn.get(keys[idx])
        image_np = np.frombuffer(imagebuf, np.uint8)
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        return image
def resize_and_save_one(args):
    """Worker: load one image from lmdb, resize it, and write it as a file.

    Args:
        args: dict with 'pre_root_dir' (lmdb path), 'index' (sample index),
            'new_image_path' (output path), and optionally 'new_size'
            ((width, height) for cv2.resize).

    Note:
        When 'new_size' is absent, this falls back to the module-level
        ``new_size`` global, which is how the original code always behaved
        (workers inherit it when the pool forks).
    """
    image = LmdbLoader.get_image(args['pre_root_dir'], args['index'])
    # Bug fix: the original always read the global `new_size`, silently
    # ignoring the size the caller intended. Prefer an explicit value.
    size = args['new_size'] if 'new_size' in args else new_size
    image = cv2.resize(image, dsize=size, interpolation=cv2.INTER_AREA)
    cv2.imwrite(args['new_image_path'], image)
def resize_and_save(pre_root_dir, new_root_dir, new_size=(256, 256), train_ratio=0.9, thread_num=8):
    """Resize every image in an lmdb archive into train/val image folders.

    Args:
        pre_root_dir: path of the source lmdb file.
        new_root_dir: output root; images go to <root>/train and <root>/val.
        new_size: (width, height) target size for cv2.resize.
        train_ratio: fraction of samples assigned to the train split.
        thread_num: number of worker processes in the Pool.
    """
    # Initialize parameters
    lmdb_file = LmdbLoader.get_lmdbfile(pre_root_dir)
    num_samples = lmdb_file['num_samples']
    train_cnt = int(num_samples * train_ratio)
    new_train_dir = f'{new_root_dir}/train'
    new_valid_dir = f'{new_root_dir}/val'
    # Bug fix: cv2.imwrite fails silently when the target directory does not
    # exist, so create the split directories up front.
    os.makedirs(new_train_dir, exist_ok=True)
    os.makedirs(new_valid_dir, exist_ok=True)
    # Generate one task dict per sample for the worker processes.
    parameters = []
    for i in range(num_samples):
        # Keys look 8-char-prefixed here; keep the original truncation.
        image_name = lmdb_file['keys'][i].decode()[:8]
        split_dir = new_train_dir if i < train_cnt else new_valid_dir
        parameters.append({
            'pre_root_dir': pre_root_dir,
            'index': i,
            # Bug fix: forward new_size explicitly instead of relying on the
            # workers reading a module-level global (the parameter above was
            # previously accepted but never used).
            'new_size': new_size,
            'new_image_path': f'{split_dir}/{image_name}.png',
        })
    # Iterate using a process pool. Bug fix: the original wrapped the
    # imap_unordered iterator in a second tqdm on top of the manual pbar,
    # which printed two interleaved progress bars.
    pool = Pool(thread_num)
    try:
        with tqdm(total=num_samples) as pbar:
            for _ in pool.imap_unordered(resize_and_save_one, parameters):
                pbar.update()
    finally:
        pool.close()
        pool.join()
if __name__ == '__main__':
    # Source: FFHQ packed as an lmdb archive (1024x1024 per the filename).
    pre_root_dir = 'ffhq_1024.lmdb'
    # Destination root; images are written under <root>/train and <root>/val.
    new_root_dir = 'ffhq_256.dir'
    # NOTE(review): `new_size` is also read as a module-level global by
    # resize_and_save_one in the forked worker processes — do not rename.
    new_size = (256, 256)
    # 99% train / 1% val split, processed by 8 worker processes.
    train_ratio = 0.99
    thread_num = 8
    resize_and_save(pre_root_dir, new_root_dir, new_size, train_ratio, thread_num)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment