Last active
August 22, 2019 07:49
-
-
Save BenZstory/ec14e54f8b9ee5ce7e4b96e3f39d9302 to your computer and use it in GitHub Desktop.
[parallel tfrecords] Create TFRecords with multiple threads #ml #tf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import re | |
import sys | |
import time | |
import json | |
import shutil | |
import random | |
import hashlib | |
import subprocess | |
import threading | |
import numpy as np | |
import pandas as pd | |
from datetime import timedelta | |
from tqdm import tqdm, tqdm_notebook, tnrange | |
import tensorflow as tf | |
from tensorflow.python.util import compat | |
from easydict import EasyDict as edict | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import h5py | |
# Notebook-only setup: `%matplotlib inline` is an IPython/Jupyter magic, not
# plain Python syntax, so this file only runs as-is inside a notebook.
%matplotlib inline
# Default matplotlib figure size as (width, height) in inches.
plt.rcParams['figure.figsize'] = (16.0, 4.0)
def get_time_str():
    """Return the current local time as a compact ``MMDD_HHMMSS`` stamp.

    Used to give each output directory a unique, sortable suffix.
    """
    now = time.localtime()
    stamp_format = "%m%d_%H%M%S"
    return time.strftime(stamp_format, now)
def touch_dir(dir):
    """Create directory ``dir`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the check and the creation a single step. With the
    check-then-create form, two concurrent callers can both see "missing", and
    the slower one then fails with FileExistsError.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)
def touch_and_clean_dir(dir):
    """Ensure ``dir`` exists as a freshly created, empty directory.

    Whatever directory tree already lives at that path is deleted first,
    together with all of its contents.
    """
    already_there = os.path.exists(dir)
    if already_there:
        shutil.rmtree(dir)
    os.makedirs(dir)
# define helper workers | |
class Task(object):
    """One unit of work: encode a slice of one split into a single file.

    Attributes:
        group: split the examples belong to (e.g. 'train').
        address_book: list of example file paths to encode.
        count: number of paths in ``address_book``.
        name: output file name, joined onto the worker's output directory.
    """

    def __init__(self, group, address_book, name):
        self.name = name
        self.group = group
        self.address_book = address_book
        self.count = len(self.address_book)
# Thread-local storage object.
# NOTE(review): not referenced anywhere in the visible code; possibly a
# leftover. Confirm no other notebook cell uses it before removing.
worker_tasks = threading.local()
class Worker(threading.Thread):
    """Thread that writes each of its assigned Tasks to its own TFRecord file.

    Args:
        name: thread name, also printed in the start-up message.
        output_dir: directory the record files are written into.
        parse_func: callable ``file_path -> tf.train.Example``. A falsy return
            value (e.g. None) or a raised exception marks that example as
            failed; it is skipped and does not abort the task.
        examples_numbers: dict mapping group name -> example count. This is the
            very dict object held by the owning ParallelFarm, so the failure
            decrements made here are visible to it.
        threadLock: lock shared by all workers that guards ``examples_numbers``.
        tasks: optional initial list of Task objects.
    """

    def __init__(self, name, output_dir, parse_func, examples_numbers, threadLock, tasks=None):
        super(Worker, self).__init__(name=name)
        self.output_dir = output_dir
        self.tasks = tasks if tasks else []
        self.parse_func = parse_func
        self.examples_numbers = examples_numbers
        self.threadLock = threadLock

    def add_task(self, task):
        """Queue ``task`` to be processed when the thread runs."""
        self.tasks.append(task)

    def run(self) -> None:
        print("Starting worker({})".format(self.name))
        for task in self.tasks:
            failed = self._write_task(task)
            if failed > 0:
                # examples_numbers is shared with ParallelFarm and every other
                # worker; `with` guarantees the lock is released even if an
                # error is raised while it is held.
                with self.threadLock:
                    print("{} files failed and skipped".format(str(failed)))
                    self.examples_numbers[task.group] -= failed

    def _write_task(self, task):
        """Encode every path of ``task`` into its record file; return the failure count."""
        failed = 0
        record_path = os.path.join(self.output_dir, task.name)
        # The context manager closes the writer even if an unexpected error
        # escapes the per-example handler below.
        with tf.python_io.TFRecordWriter(record_path) as tf_writer:
            for file_path in tqdm_notebook(task.address_book, desc=task.name):
                try:
                    tf_example = self.parse_func(file_path)
                    if not tf_example:
                        failed += 1
                        continue
                    tf_writer.write(tf_example.SerializeToString())
                except Exception as e:
                    # Deliberate best-effort: one bad example must not kill the
                    # whole shard, but say which file it was.
                    print('Exception when parse: ', file_path, e)
                    failed += 1
        return failed
class ParallelFarm():
    """Split example paths into train/val/test and write them as TFRecord shards.

    Args:
        example_address_book: list of example file paths; it is not modified.
        output_dir: directory that receives the record files and ``meta.json``.
        parse_func: ``file_path -> tf.train.Example`` callable given to every Worker.
        worker_num: number of Worker threads (must be >= 1).
        group_percent: mapping with fractions for ``'train'`` and ``'val'``.
            The test split receives every remaining example, so a ``'test'``
            entry, if present, is not read.
        max_per_file: maximum number of examples per output file (must be >= 1).

    Raises:
        ValueError: if ``worker_num`` or ``max_per_file`` is < 1, or if the
            fractions yield a negative count for any split.
    """

    GROUPS = ('train', 'val', 'test')

    def __init__(self, example_address_book, output_dir, parse_func, worker_num, group_percent, max_per_file):
        if worker_num < 1:
            raise ValueError('worker_num must be >= 1, got {}'.format(worker_num))
        if max_per_file < 1:
            # A non-positive shard size would make split_tasks loop forever.
            raise ValueError('max_per_file must be >= 1, got {}'.format(max_per_file))
        self.example_address_book = example_address_book
        self.output_dir = output_dir
        self.parse_func = parse_func
        self.worker_num = worker_num
        self.group_percent = group_percent
        self.max_per_file = max_per_file
        self.examples_numbers = {}
        self.address_book_of_group = {}
        self.tasks = []
        self.workers = []
        self.source_list = []
        self.batch_list = {}

        total = len(example_address_book)
        n_train = int(total * group_percent['train'])
        n_val = int(total * group_percent['val'])
        n_test = total - n_train - n_val
        if min(n_train, n_val, n_test) < 0:
            raise ValueError('group_percent yields negative split sizes: train={}, val={}, test={}'.format(
                n_train, n_val, n_test))
        self.examples_numbers['train'] = n_train
        self.examples_numbers['val'] = n_val
        self.examples_numbers['test'] = n_test

        # Shuffle one copy and cut it into contiguous slices so the splits are
        # disjoint. Sampling each split independently from the full list would
        # let the same example land in several splits (train/val leakage).
        shuffled = list(example_address_book)
        random.shuffle(shuffled)
        bounds = {
            'train': (0, n_train),
            'val': (n_train, n_train + n_val),
            'test': (n_train + n_val, total),
        }
        for group in self.GROUPS:
            lo, hi = bounds[group]
            self.address_book_of_group[group] = shuffled[lo:hi]
            self.batch_list[group] = []

    @staticmethod
    def _task_name(group, task_id):
        """File name of the ``task_id``-th shard of ``group``."""
        if task_id == 0:
            return group + '.tfrecords'
        return group + '_' + str(task_id) + '.tfrecords'

    def split_tasks(self):
        """Cut each split into Tasks of at most ``max_per_file`` examples.

        Rebuilds ``self.tasks`` and ``self.batch_list`` from scratch. Calling it
        twice therefore cannot queue duplicate tasks that would write the same
        file concurrently. Every split gets at least one Task, even an empty
        split, so every split has a file name listed in ``meta.json``.
        """
        self.tasks = []
        for group in self.GROUPS:
            book = self.address_book_of_group[group]
            self.batch_list[group] = []
            # max(..., 1) keeps a single (empty) task for an empty split.
            starts = range(0, max(len(book), 1), self.max_per_file)
            for task_id, start in enumerate(starts):
                task_name = self._task_name(group, task_id)
                chunk = book[start:start + self.max_per_file]
                self.tasks.append(Task(group, chunk, name=task_name))
                self.batch_list[group].append(task_name)

    def run(self):
        """Split into tasks, deal them round-robin to the workers, and wait for all of them."""
        self.split_tasks()
        threadLock = threading.Lock()
        self.workers = [
            Worker('worker' + str(i), self.output_dir, self.parse_func, self.examples_numbers, threadLock)
            for i in range(self.worker_num)
        ]
        for i, task in enumerate(self.tasks):
            self.workers[i % self.worker_num].add_task(task)
        for worker in self.workers:
            worker.start()
        for worker in self.workers:
            worker.join()

    def dump_meta_json(self, output_dir=None, extra_params=None):
        """Write ``meta.json`` with per-split example counts and shard file lists.

        After ``run()``, the counts already exclude examples that workers failed
        to parse. Entries in ``extra_params`` are added to (and may override)
        the top-level keys. Defaults to writing into ``self.output_dir``.
        """
        if not output_dir:
            output_dir = self.output_dir
        content = {
            'num_examples': self.examples_numbers,
            'train_batch_list': self.batch_list['train'],
            'val_batch_list': self.batch_list['val'],
            'test_batch_list': self.batch_list['test'],
        }
        if extra_params:
            content.update(extra_params)
        with open(os.path.join(output_dir, 'meta.json'), 'w') as f:
            json.dump(content, f)
# define example gather and parser | |
def walkdir(rootdir, address_book=None, extensions=('.jpg', '.png', '.jpeg')):
    """Recursively collect image file paths found under ``rootdir``.

    Args:
        rootdir: directory to scan; sub-directories are descended into.
        address_book: list the found paths are appended to. Defaults to the
            module-level ``example_address_book`` list, which must already
            exist when no explicit list is given.
        extensions: file extensions to keep, compared case-insensitively, so
            ``photo.JPG`` is collected too.

    Returns:
        The list the paths were appended to.
    """
    if address_book is None:
        address_book = example_address_book
    wanted = tuple(ext.lower() for ext in extensions)
    for node in os.listdir(rootdir):
        path = os.path.join(rootdir, node)
        if os.path.isdir(path):
            # A directory is never an example itself, even if its name ends in
            # an image extension; only its contents are collected.
            walkdir(path, address_book, extensions)
        elif os.path.splitext(node)[1].lower() in wanted:
            address_book.append(path)
    return address_book
# Side length in pixels: parse_example resizes both the image and its mask to
# TARGET_SIZE x TARGET_SIZE before serializing them.
TARGET_SIZE = 512
def parse_example(file_path, mask_root=None, target_size=None):
    """Encode one image and its segmentation mask as a ``tf.train.Example``.

    The mask is looked up as ``<mask_root>/<image file stem>.png``. The image
    (converted to RGB) and the mask are both resized to
    ``target_size x target_size`` and stored as raw ``uint8`` bytes. No shape
    information is written into the record.

    Args:
        file_path: path of the input image.
        mask_root: directory holding the masks; defaults to the module-level
            ``mask_dir``.
        target_size: output side length in pixels; defaults to the
            module-level ``TARGET_SIZE``.

    Returns:
        A ``tf.train.Example`` with bytes features ``'image'``, ``'mask'`` and
        ``'portrait_mask'`` (1 where the mask is non-zero, 0 elsewhere).

    Raises:
        IOError: if the mask file does not exist.
    """
    if mask_root is None:
        mask_root = mask_dir
    index = os.path.splitext(os.path.basename(file_path))[0]
    mask_path = os.path.join(mask_root, index + '.png')
    if not os.path.exists(mask_path):
        raise IOError(str(mask_path) + ' not found.')
    if target_size is None:
        target_size = TARGET_SIZE
    size = (target_size, target_size)

    # `with` closes the underlying files deterministically; this runs once per
    # example on many threads, so leaked handles would accumulate.
    with Image.open(file_path) as image:
        image_data = np.array(image.convert('RGB').resize(size)).tobytes()
    with Image.open(mask_path) as mask_img:
        # Masks hold discrete label values. NEAREST keeps them intact, whereas
        # Pillow >= 7 defaults resize() to BICUBIC for most modes, which blends
        # neighbouring labels into new, invalid values.
        mask = np.uint8(np.array(mask_img.resize(size, Image.NEAREST)))
    portrait_mask = np.uint8(mask > 0)

    def _bytes_feature(data):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))

    return tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_data),
        'mask': _bytes_feature(mask.tobytes()),
        'portrait_mask': _bytes_feature(portrait_mask.tobytes()),
    }))
# config params and start threads | |
DATASET_PARAMS_DICT = {
    # Fractions of the examples assigned to each split. ParallelFarm reads only
    # 'train' and 'val'; the test split receives whatever remains.
    "percent": {
        "train": 0.95,
        "val": 0.05,
        "test": 0.0
    },
    # Maximum number of examples written into a single .tfrecords shard.
    "max_per_file": 3000,
    "sample_per_class_max": 0, # whether to balance sample counts across classes, and the max count per class; NOTE(review): not read anywhere in the visible code
    "parallel_num": 10 # number of Worker threads (30 is the alternative value noted here)
}
# EasyDict allows attribute access, e.g. PARAMS.max_per_file.
PARAMS = edict(DATASET_PARAMS_DICT)
ROOT_DIR = os.path.join("/home/dl/i19tdata/proj_portrat2anime", 'CelebAMask-HQ')
# Input images, scanned recursively by walkdir below.
face_dir = os.path.join(ROOT_DIR, 'CelebA-HQ-img')
# Read as a module-level global by parse_example: one <image stem>.png mask per image.
mask_dir = os.path.join(ROOT_DIR, 'CelebAMask-total-mask')
# Output directory with a MMDD_HHMMSS suffix from get_time_str().
target_dir = os.path.join(ROOT_DIR, 'facial_segmentation_tfrecords_' + get_time_str())
# Delete any existing directory at this path, then recreate it empty.
touch_and_clean_dir(target_dir)
# walkdir appends into this module-level list, so it must exist before the call.
example_address_book = []
walkdir(face_dir)
random.shuffle(example_address_book)
farm = ParallelFarm(example_address_book,
                    output_dir = target_dir,
                    parse_func=parse_example,
                    worker_num=PARAMS.parallel_num,
                    group_percent=PARAMS.percent,
                    max_per_file=PARAMS.max_per_file)
# Returns only after every worker thread has been joined.
farm.run()
# Writes meta.json into target_dir (the farm's output_dir).
farm.dump_meta_json()
# Rename the generated TFRecord shards to sharded "<group>-NNNNN-of-NNNNN" names
#====================================== | |
import os | |
import shutil | |
def rename_tfrecord_with_meta(meta_json_path):
    """Rename the shard files listed in a meta file to sharded names.

    For each of the 'train', 'val' and 'test' groups, every file named in
    ``meta['<group>_batch_list']`` is resolved relative to the directory that
    contains ``meta_json_path``. Each file is moved to
    ``'{group}-{index:05d}-of-{shard_count:05d}.tfrecord'``. The meta file is
    then rewritten (indented) with the new names. Groups without a
    ``<group>_batch_list`` entry are left untouched.

    Raises:
        FileNotFoundError: if ``meta_json_path`` does not exist.
    """
    # A real exception rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(meta_json_path):
        raise FileNotFoundError(meta_json_path)
    base_dir = os.path.dirname(os.path.abspath(meta_json_path))
    with open(meta_json_path, 'r') as f:
        meta = json.load(f)

    for group in ('train', 'val', 'test'):
        group_key = group + '_batch_list'
        if group_key not in meta:
            continue
        # Use the names recorded in the meta file instead of re-deriving them
        # from the index. This follows whatever names were actually written,
        # and re-running on already-renamed files moves each file onto itself.
        old_names = meta[group_key]
        shard_count = len(old_names)
        new_names = []
        for i, original_name in enumerate(old_names):
            new_name = '{}-{:05d}-of-{:05d}.tfrecord'.format(group, i, shard_count)
            shutil.move(os.path.join(base_dir, original_name), os.path.join(base_dir, new_name))
            new_names.append(new_name)
        meta[group_key] = new_names

    # write new meta back
    with open(meta_json_path, 'w') as fout:
        fout.write(json.dumps(meta, indent=4))
# ParallelFarm.dump_meta_json() writes `meta.json` into its output_dir, which
# the driver above sets to target_dir; rename the shards listed there.
rename_tfrecord_with_meta(os.path.join(target_dir, 'meta.json'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment