[parallel tfrecords] create TFRecords with multiple threads #ml #tf
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import sys
import time
import json
import shutil
import random
import hashlib
import subprocess
import threading
import numpy as np
import pandas as pd
from datetime import timedelta
from tqdm import tqdm, tqdm_notebook, tnrange
import tensorflow as tf
from tensorflow.python.util import compat
from easydict import EasyDict as edict
from PIL import Image
import matplotlib.pyplot as plt
import h5py
%matplotlib inline
plt.rcParams['figure.figsize'] = (16.0, 4.0)
def get_time_str():
    return time.strftime("%m%d_%H%M%S", time.localtime())

def touch_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def touch_and_clean_dir(dir):
    if os.path.exists(dir):
        shutil.rmtree(dir)
    os.makedirs(dir)
# define helper workers
class Task(object):
    def __init__(self, group, address_book, name):
        self.group = group
        self.address_book = address_book
        self.count = len(address_book)
        self.name = name
class Worker(threading.Thread):
    def __init__(self, name, output_dir, parse_func, examples_numbers, threadLock, tasks=None):
        super(Worker, self).__init__()
        self.name = name
        self.output_dir = output_dir
        self.tasks = tasks if tasks else []
        self.parse_func = parse_func
        self.examples_numbers = examples_numbers
        self.threadLock = threadLock

    def add_task(self, task):
        self.tasks.append(task)

    def run(self):
        print("Starting worker({})".format(self.name))
        for task in self.tasks:
            tf_writer = tf.python_io.TFRecordWriter(os.path.join(self.output_dir, task.name))
            failed = 0
            for file_path in tqdm_notebook(task.address_book, desc=task.name):
                try:
                    tf_example = self.parse_func(file_path)
                    if tf_example:
                        tf_writer.write(tf_example.SerializeToString())
                    else:
                        failed += 1
                except Exception as e:
                    print('Exception while parsing {}: {}'.format(file_path, e))
                    failed += 1
            tf_writer.close()
            if failed > 0:
                # examples_numbers is the dict owned by ParallelFarm; it is shared by
                # reference, so updating it in place under the lock keeps the farm's
                # per-group counts accurate across workers.
                self.threadLock.acquire()
                print("{} files failed and were skipped".format(failed))
                self.examples_numbers[task.group] -= failed
                self.threadLock.release()
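
# --- optional smoke test (not part of the original pipeline) -----------------
# A minimal sketch of driving a single Worker directly with a dummy parse
# function, useful for checking the Task/Worker plumbing before launching the
# full farm. The output directory '/tmp/worker_smoke_test' and the helper names
# _dummy_parse / _worker_smoke_test are illustrative assumptions, not part of
# the original script.
def _dummy_parse(file_path):
    # encode only the path string so no real image files are needed
    return tf.train.Example(features=tf.train.Features(feature={
        'path': tf.train.Feature(bytes_list=tf.train.BytesList(value=[compat.as_bytes(file_path)]))
    }))

def _worker_smoke_test():
    out_dir = '/tmp/worker_smoke_test'
    touch_and_clean_dir(out_dir)
    counts = {'train': 3}
    lock = threading.Lock()
    w = Worker('worker_smoke', out_dir, _dummy_parse, counts, lock)
    w.add_task(Task('train', ['a.jpg', 'b.jpg', 'c.jpg'], name='train.tfrecords'))
    w.start()
    w.join()
    print('records written to', os.path.join(out_dir, 'train.tfrecords'))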
class ParallelFarm():
    def __init__(self, example_address_book, output_dir, parse_func, worker_num, group_percent, max_per_file):
        self.example_address_book = example_address_book
        self.output_dir = output_dir
        self.parse_func = parse_func
        self.worker_num = worker_num
        self.group_percent = group_percent
        self.max_per_file = max_per_file
        self.examples_numbers = {}
        self.address_book_of_group = {}
        self.tasks = []
        self.workers = []
        self.source_list = []
        self.batch_list = {}
        self.examples_numbers['train'] = int(len(example_address_book) * group_percent['train'])
        self.examples_numbers['val'] = int(len(example_address_book) * group_percent['val'])
        self.examples_numbers['test'] = len(example_address_book) - self.examples_numbers['train'] - self.examples_numbers['val']
        # slice the (already shuffled) address book into disjoint chunks so that
        # train/val/test never share examples
        offset = 0
        for group in ['train', 'val', 'test']:
            count = self.examples_numbers[group]
            self.address_book_of_group[group] = example_address_book[offset:offset + count]
            offset += count
            self.batch_list[group] = []
    def split_tasks(self):
        # loop over each group's examples and split them into per-file tasks
        for group in ['train', 'val', 'test']:
            if not self.address_book_of_group[group]:
                # skip empty groups (e.g. test when its percent is 0) instead of
                # writing an empty .tfrecords file
                continue
            task_id = 0
            start = 0
            end = min(self.max_per_file, len(self.address_book_of_group[group]))
            while end < len(self.address_book_of_group[group]):
                if task_id == 0:
                    task_name = group + '.tfrecords'
                else:
                    task_name = group + '_' + str(task_id) + '.tfrecords'
                new_task = Task(group, self.address_book_of_group[group][start:end], name=task_name)
                self.tasks.append(new_task)
                self.batch_list[group].append(task_name)
                start += self.max_per_file
                end += self.max_per_file
                task_id += 1
            if task_id == 0:
                task_name = group + '.tfrecords'
            else:
                task_name = group + '_' + str(task_id) + '.tfrecords'
            end = len(self.address_book_of_group[group])
            new_task = Task(group, self.address_book_of_group[group][start:end], name=task_name)
            self.tasks.append(new_task)
            self.batch_list[group].append(task_name)
    def run(self):
        self.split_tasks()
        threadLock = threading.Lock()
        # init workers
        for i in range(0, self.worker_num):
            worker_name = 'worker' + str(i)
            new_worker = Worker(worker_name, self.output_dir, self.parse_func, self.examples_numbers, threadLock)
            self.workers.append(new_worker)
        # assign tasks to workers round-robin
        for i in range(0, len(self.tasks)):
            self.workers[i % self.worker_num].add_task(self.tasks[i])
        # start workers
        for worker in self.workers:
            worker.start()
        # wait for all workers to finish
        for worker in self.workers:
            worker.join()
    def dump_meta_json(self, output_dir=None, extra_params=None):
        if not output_dir:
            output_dir = self.output_dir
        with open(os.path.join(output_dir, 'meta.json'), 'w') as f:
            content = {
                'num_examples': self.examples_numbers,
                'train_batch_list': self.batch_list['train'],
                'val_batch_list': self.batch_list['val'],
                'test_batch_list': self.batch_list['test'],
            }
            if extra_params:
                for k, v in extra_params.items():
                    content[k] = v
            json.dump(content, f)
# define example gather and parser
def walkdir(rootdir):
    # recursively collect image paths into the module-level example_address_book
    for node in os.listdir(rootdir):
        path = os.path.join(rootdir, node)
        if os.path.isdir(path):
            walkdir(path)
        elif os.path.splitext(node)[1] in ['.jpg', '.png', '.jpeg']:
            example_address_book.append(path)
TARGET_SIZE = 512

def parse_example(file_path):
    # relies on the module-level mask_dir defined below
    filename = os.path.basename(file_path)
    index = os.path.splitext(filename)[0]
    mask_path = os.path.join(mask_dir, index + '.png')
    if not os.path.exists(mask_path):
        raise IOError(str(mask_path) + ' not found.')
    image = Image.open(file_path)
    image = image.convert('RGB')
    image = image.resize((TARGET_SIZE, TARGET_SIZE))
    image_data = np.array(image).tobytes()
    mask_img = Image.open(mask_path)
    # use nearest-neighbor resampling so label values are not interpolated
    mask_img = mask_img.resize((TARGET_SIZE, TARGET_SIZE), Image.NEAREST)
    mask = np.uint8(np.array(mask_img))
    mask_data = mask.tobytes()
    portrait_mask = np.uint8(mask > 0)
    portrait_mask_data = portrait_mask.tobytes()
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
        'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask_data])),
        'portrait_mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[portrait_mask_data]))
    }))
    return tf_example
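
# --- reader-side sketch (not part of the original script) --------------------
# A minimal decode function matching the feature keys written above, assuming
# TF 1.x, the same TARGET_SIZE used at write time, and single-channel masks.
# The helper name decode_example is an assumption for illustration.
def decode_example(serialized_example):
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'mask': tf.FixedLenFeature([], tf.string),
        'portrait_mask': tf.FixedLenFeature([], tf.string),
    })
    image = tf.reshape(tf.decode_raw(features['image'], tf.uint8),
                       [TARGET_SIZE, TARGET_SIZE, 3])
    mask = tf.reshape(tf.decode_raw(features['mask'], tf.uint8),
                      [TARGET_SIZE, TARGET_SIZE])
    portrait_mask = tf.reshape(tf.decode_raw(features['portrait_mask'], tf.uint8),
                               [TARGET_SIZE, TARGET_SIZE])
    return image, mask, portrait_mask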
# config params and start threads
DATASET_PARAMS_DICT = {
    "percent": {
        "train": 0.95,
        "val": 0.05,
        "test": 0.0
    },
    "max_per_file": 3000,
    "sample_per_class_max": 0,  # cap on samples per class when balancing classes; 0 disables balancing
    "parallel_num": 10  # 30
}
PARAMS = edict(DATASET_PARAMS_DICT)
ROOT_DIR = os.path.join("/home/dl/i19tdata/proj_portrat2anime", 'CelebAMask-HQ')
face_dir = os.path.join(ROOT_DIR, 'CelebA-HQ-img')
mask_dir = os.path.join(ROOT_DIR, 'CelebAMask-total-mask')
target_dir = os.path.join(ROOT_DIR, 'facial_segmentation_tfrecords_' + get_time_str())
touch_and_clean_dir(target_dir)
example_address_book = []
walkdir(face_dir)
random.shuffle(example_address_book)
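
# --- optional visual spot check (not in the original snippet) ----------------
# Using the matplotlib setup from the imports above, plot one face next to its
# mask to confirm the image/mask pairing before spending time writing TFRecords.
sample_path = example_address_book[0]
sample_index = os.path.splitext(os.path.basename(sample_path))[0]
sample_mask = Image.open(os.path.join(mask_dir, sample_index + '.png'))
plt.subplot(1, 2, 1); plt.imshow(Image.open(sample_path)); plt.title('image')
plt.subplot(1, 2, 2); plt.imshow(np.array(sample_mask)); plt.title('mask')
plt.show()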
farm = ParallelFarm(example_address_book,
                    output_dir=target_dir,
                    parse_func=parse_example,
                    worker_num=PARAMS.parallel_num,
                    group_percent=PARAMS.percent,
                    max_per_file=PARAMS.max_per_file)
farm.run()
farm.dump_meta_json()
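
# --- optional verification (not part of the original script) -----------------
# Count the records actually written per group and compare against meta.json,
# using the TF 1.x tf.python_io.tf_record_iterator API; the underscore-prefixed
# names are illustrative only.
with open(os.path.join(target_dir, 'meta.json')) as f:
    _meta = json.load(f)
for _group in ['train', 'val', 'test']:
    _written = 0
    for _shard in _meta[_group + '_batch_list']:
        _written += sum(1 for _ in tf.python_io.tf_record_iterator(os.path.join(target_dir, _shard)))
    print(_group, 'expected', _meta['num_examples'][_group], 'written', _written)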
# Rename those tfrecords to the standard sharded naming scheme
# ======================================
import os
import json
import shutil

def rename_tfrecord_with_meta(meta_json_path):
    assert os.path.exists(meta_json_path)
    base_dir = os.path.abspath(os.path.join(meta_json_path, '..'))
    meta = {}
    with open(meta_json_path, 'r') as f:
        meta = json.loads(f.read())
    group_tag = ['train', 'val', 'test']
    shards = {}
    for group in group_tag:
        group_key = group + '_batch_list'
        shards[group] = len(meta[group_key])
        new_batch_filenames = []
        for i in range(shards[group]):
            if i == 0:
                original_name = group + '.tfrecords'
            else:
                original_name = group + '_' + str(i) + '.tfrecords'
            new_name = '{}-{:05d}-of-{:05d}.tfrecord'.format(group, i, shards[group])
            new_batch_filenames.append(new_name)
            shutil.move(os.path.join(base_dir, original_name), os.path.join(base_dir, new_name))
        meta[group_key] = new_batch_filenames
    # write the updated shard lists back into meta.json
    with open(meta_json_path, 'w') as fout:
        fout.write(json.dumps(meta, indent=4))

# rename the shards and update the meta.json produced by the run above
rename_tfrecord_with_meta(os.path.join(target_dir, 'meta.json'))
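
# --- input-pipeline sketch (not part of the original script) -----------------
# Once the shards are renamed, a TF 1.x tf.data pipeline can be built from the
# updated meta.json; decode_example refers to the reader-side sketch above, and
# the batch size, shuffle buffer, and helper name make_dataset are arbitrary
# example values.
def make_dataset(meta_json_path, group='train', batch_size=8):
    base_dir = os.path.dirname(os.path.abspath(meta_json_path))
    with open(meta_json_path) as f:
        meta = json.load(f)
    files = [os.path.join(base_dir, name) for name in meta[group + '_batch_list']]
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(decode_example, num_parallel_calls=4)
    dataset = dataset.shuffle(256).batch(batch_size).prefetch(1)
    return dataset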