Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
[parallel tfrecords] Create TFRecords with multiple threads #ml #tf
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import sys
import time
import json
import shutil
import random
import hashlib
import subprocess
import threading
import numpy as np
import pandas as pd
from datetime import timedelta
from tqdm import tqdm, tqdm_notebook, tnrange
import tensorflow as tf
from tensorflow.python.util import compat
from easydict import EasyDict as edict
from PIL import Image
import matplotlib.pyplot as plt
import h5py
# IPython/Jupyter line magic: render figures inline in the notebook.
# NOTE(review): this line is not valid syntax in a plain Python script.
%matplotlib inline
# Default figure size (width, height) for subsequent matplotlib figures.
plt.rcParams['figure.figsize'] = (16.0, 4.0)
def get_time_str():
    """Return the current local time formatted as ``MMDD_HHMMSS``."""
    now = time.localtime()
    stamp_format = "%m%d_%H%M%S"
    return time.strftime(stamp_format, now)
def touch_dir(dir):
    """Create directory ``dir`` (including missing parents) if it is absent.

    Nothing happens when the path already exists.

    Args:
        dir: Path of the directory to create. The name shadows the ``dir``
            builtin; it is kept so keyword callers keep working.
    """
    # Attempt the creation directly instead of "check, then create": another
    # thread or process may create the path between the two steps, which made
    # the original sequence fail with FileExistsError.
    try:
        os.makedirs(dir)
    except FileExistsError:
        pass
def touch_and_clean_dir(dir):
    """Leave ``dir`` as a freshly created, empty directory.

    An existing directory tree at that path is deleted first, then the
    directory (and any missing parents) is created again.
    """
    stale = os.path.exists(dir)
    if stale:
        shutil.rmtree(dir)
    os.makedirs(dir)
# define helper workers
class Task(object):
    """One unit of work: the examples that go into a single output file.

    Attributes:
        group: Name of the split this task belongs to (the farm uses
            'train', 'val' and 'test').
        address_book: Sequence of example entries handled by this task.
        count: Number of entries in ``address_book`` at construction time.
        name: File name of the output the task is written to.
    """

    def __init__(self, group, address_book, name):
        self.name = name
        self.group = group
        self.address_book = address_book
        # Cached once; later changes to address_book are not reflected here.
        self.count = len(address_book)
# Thread-local storage object. Nothing in the visible code reads or writes it.
worker_tasks = threading.local()
class Worker(threading.Thread):
    """Thread that converts its assigned tasks into TFRecord files.

    For every task the worker opens ``output_dir/<task.name>``, calls
    ``parse_func`` on each entry of ``task.address_book`` and writes the
    serialized result. Entries that raise, or for which ``parse_func``
    returns a falsy value, are counted as failed and skipped.

    Args:
        name: Thread name, also used in the start-up message.
        output_dir: Directory the TFRecord files are written into.
        parse_func: Callable taking one address-book entry and returning an
            object with ``SerializeToString()`` (or a falsy value to skip).
        examples_numbers: Mapping from group name to example count, shared
            with the caller; failed examples are subtracted from it.
        threadLock: Lock guarding updates to ``examples_numbers``.
        tasks: Optional initial list of tasks; a new list is used when it is
            omitted or empty.
    """

    def __init__(self, name, output_dir, parse_func, examples_numbers, threadLock, tasks=None):
        super(Worker, self).__init__()
        self.name = name
        self.output_dir = output_dir
        self.tasks = tasks if tasks else []
        self.parse_func = parse_func
        self.examples_numbers = examples_numbers
        self.threadLock = threadLock

    def add_task(self, task):
        """Append ``task`` to this worker's queue."""
        self.tasks.append(task)

    def run(self) -> None:
        """Process every queued task and report failures to the shared counts."""
        print("Starting worker({})".format(self.name))
        for task in self.tasks:
            failed = self._write_task(task)
            if failed > 0:
                # `with` releases the lock even if the update raises (for
                # example on an unknown group key), so other workers cannot
                # be left blocked forever.
                with self.threadLock:
                    print("{} files failed and skipped".format(str(failed)))
                    # examples_numbers is the caller's dict object, so this
                    # in-place update is visible to whoever created the worker.
                    self.examples_numbers[task.group] -= failed

    def _write_task(self, task):
        """Write one task to its TFRecord file and return the failure count."""
        tf_writer = tf.python_io.TFRecordWriter(os.path.join(self.output_dir, task.name))
        failed = 0
        try:
            for file_path in tqdm_notebook(task.address_book, desc=task.name):
                try:
                    tf_example = self.parse_func(file_path)
                    if tf_example:
                        tf_writer.write(tf_example.SerializeToString())
                    else:
                        failed += 1
                except Exception as e:
                    # Best effort per example: report it, count it, move on.
                    print('Exception when parse: ', e)
                    failed += 1
        finally:
            # Always close the writer, even when something other than an
            # Exception (e.g. KeyboardInterrupt) aborts the loop.
            tf_writer.close()
        return failed
class ParallelFarm():
    """Split examples into train/val/test shards and write them with threads.

    The examples are shuffled once and cut into three disjoint partitions.
    Each partition is split into tasks of at most ``max_per_file`` examples,
    and the tasks are distributed round-robin over ``worker_num`` ``Worker``
    threads.

    Args:
        example_address_book: Sequence of example entries; each entry is
            later handed to ``parse_func``. The sequence is not modified.
        output_dir: Directory the TFRecord files and ``meta.json`` go into.
        parse_func: Callable passed on to every ``Worker``.
        worker_num: Number of worker threads to start.
        group_percent: Mapping with keys 'train', 'val' and 'test'; only the
            'train' and 'val' fractions are used, 'test' receives the rest.
        max_per_file: Maximum number of examples per output file.

    Raises:
        ValueError: If the 'train' and 'val' fractions together select more
            examples than are available.
    """

    # Split names, in the order they are cut from the shuffled examples.
    _GROUPS = ('train', 'val', 'test')

    def __init__(self, example_address_book, output_dir, parse_func, worker_num, group_percent, max_per_file):
        self.example_address_book = example_address_book
        self.output_dir = output_dir
        self.parse_func = parse_func
        self.worker_num = worker_num
        self.group_percent = group_percent
        self.max_per_file = max_per_file
        self.examples_numbers = {}
        self.address_book_of_group = {}
        self.tasks = []
        self.workers = []
        self.source_list = []
        self.batch_list = {}

        total = len(example_address_book)
        train_count = int(total * group_percent['train'])
        val_count = int(total * group_percent['val'])
        test_count = total - train_count - val_count
        if test_count < 0:
            raise ValueError(
                "group_percent selects more examples than available: "
                "train={} + val={} > total={}".format(train_count, val_count, total))
        self.examples_numbers['train'] = train_count
        self.examples_numbers['val'] = val_count
        self.examples_numbers['test'] = test_count

        # Shuffle a copy once and slice it. Sampling each group independently
        # from the full list would let the same example land in several
        # splits and leave others unused.
        shuffled = list(example_address_book)
        random.shuffle(shuffled)
        start = 0
        for group in self._GROUPS:
            end = start + self.examples_numbers[group]
            self.address_book_of_group[group] = shuffled[start:end]
            self.batch_list[group] = []
            start = end

    @staticmethod
    def _shard_name(group, task_id):
        """Return the output file name for shard ``task_id`` of ``group``."""
        if task_id == 0:
            return group + '.tfrecords'
        return group + '_' + str(task_id) + '.tfrecords'

    def split_tasks(self):
        """Cut every group into tasks of at most ``max_per_file`` examples.

        Each task is appended to ``self.tasks`` and its file name to
        ``self.batch_list[group]``. An empty group still yields one (empty)
        task, so every group has at least one listed file.

        Raises:
            ValueError: If ``max_per_file`` is smaller than 1.
        """
        if self.max_per_file < 1:
            # A non-positive shard size can never make progress.
            raise ValueError(
                "max_per_file must be at least 1, got {}".format(self.max_per_file))
        for group in self._GROUPS:
            book = self.address_book_of_group[group]
            # `or [0]` keeps the single empty shard for an empty group.
            starts = list(range(0, len(book), self.max_per_file)) or [0]
            for task_id, start in enumerate(starts):
                task_name = self._shard_name(group, task_id)
                chunk = book[start:start + self.max_per_file]
                self.tasks.append(Task(group, chunk, name=task_name))
                self.batch_list[group].append(task_name)

    def run(self):
        """Split the work, start all workers and block until they finish.

        Raises:
            ValueError: If ``worker_num`` is smaller than 1.
        """
        if self.worker_num < 1:
            raise ValueError(
                "worker_num must be at least 1, got {}".format(self.worker_num))
        self.split_tasks()
        threadLock = threading.Lock()
        # init workers; they share examples_numbers and the lock guarding it
        for i in range(self.worker_num):
            worker_name = 'worker' + str(i)
            new_worker = Worker(worker_name, self.output_dir, self.parse_func,
                                self.examples_numbers, threadLock)
            self.workers.append(new_worker)
        # assign tasks to workers round-robin
        for i, task in enumerate(self.tasks):
            self.workers[i % self.worker_num].add_task(task)
        # start workers
        for worker in self.workers:
            worker.start()
        # wait workers
        for worker in self.workers:
            worker.join()

    def dump_meta_json(self, output_dir=None, extra_params=None):
        """Write ``meta.json`` with the example counts and file lists.

        Args:
            output_dir: Target directory; defaults to ``self.output_dir``.
            extra_params: Optional mapping whose items are added to the JSON
                object (overriding the default keys on a name clash).
        """
        if not output_dir:
            output_dir = self.output_dir
        with open(os.path.join(output_dir, 'meta.json'), 'w') as f:
            content = {
                'num_examples': self.examples_numbers,
                'train_batch_list': self.batch_list['train'],
                'val_batch_list': self.batch_list['val'],
                'test_batch_list': self.batch_list['test'],
            }
            if extra_params:
                content.update(extra_params)
            json.dump(content, f)
# define example gather and parser
def walkdir(rootdir, address_book=None):
    """Recursively collect image file paths below ``rootdir``.

    Files whose extension is exactly '.jpg', '.png' or '.jpeg'
    (case-sensitive) are appended to ``address_book``.

    Args:
        rootdir: Directory to scan.
        address_book: List that receives the paths. When omitted, the
            module-level ``example_address_book`` list is used.

    Returns:
        The list the paths were appended to.
    """
    if address_book is None:
        # Resolved at call time, so the global only has to exist by the time
        # walkdir() is actually invoked.
        address_book = example_address_book
    for node in os.listdir(rootdir):
        path = os.path.join(rootdir, node)
        if os.path.isdir(path):
            walkdir(path, address_book)
        # `elif`: a directory whose name happens to end in an image extension
        # must be descended into, not recorded as an example itself.
        elif os.path.splitext(node)[1] in ('.jpg', '.png', '.jpeg'):
            address_book.append(path)
    return address_book
TARGET_SIZE = 512
def parse_example(file_path):
    """Build a ``tf.train.Example`` from an image and its mask.

    The mask is looked up in the module-level ``mask_dir`` under the image's
    base name with a '.png' extension. Image and mask are both resized to
    ``TARGET_SIZE`` x ``TARGET_SIZE`` and stored as raw bytes.

    Args:
        file_path: Path of the input image.

    Returns:
        A ``tf.train.Example`` with three bytes features: 'image' (RGB
        pixels), 'mask' (mask values as uint8) and 'portrait_mask' (1 where
        the mask is greater than 0, else 0, as uint8).

    Raises:
        IOError: If no mask file exists for the image.
    """
    filename = os.path.basename(file_path)
    index = os.path.splitext(filename)[0]
    mask_path = os.path.join(mask_dir, index + '.png')
    if not os.path.exists(mask_path):
        raise IOError(str(mask_path) + ' not found.')

    target_size = (TARGET_SIZE, TARGET_SIZE)

    # Context managers close the underlying files promptly. This function
    # runs for every example on several threads, so relying on garbage
    # collection can exhaust file descriptors.
    with Image.open(file_path) as image:
        image_data = np.array(image.convert('RGB').resize(target_size)).tobytes()

    with Image.open(mask_path) as mask_img:
        # Nearest-neighbour keeps mask values intact. The library default
        # filter may interpolate and create values that are in neither
        # neighbouring pixel. NOTE(review): assumes the mask stores discrete
        # labels; confirm against the dataset.
        mask_img = mask_img.resize(target_size, Image.NEAREST)
        mask = np.uint8(np.array(mask_img))
    mask_data = mask.tobytes()

    portrait_mask = np.uint8(mask > 0)
    portrait_mask_data = portrait_mask.tobytes()

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
        'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask_data])),
        'portrait_mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[portrait_mask_data]))
    }))
    return tf_example
# config params and start threads
# Split fractions, shard size and thread count for the conversion run.
DATASET_PARAMS_DICT = {
    "percent": {"train": 0.95, "val": 0.05, "test": 0.0},
    "max_per_file": 3000,
    # Whether to balance the sample count across classes, and the max count.
    # NOTE(review): not read anywhere in the visible code.
    "sample_per_class_max": 0,
    "parallel_num": 10,  # 30 was noted as an alternative value
}
PARAMS = edict(DATASET_PARAMS_DICT)

# Input and output locations.
ROOT_DIR = os.path.join("/home/dl/i19tdata/proj_portrat2anime", 'CelebAMask-HQ')
face_dir = os.path.join(ROOT_DIR, 'CelebA-HQ-img')
mask_dir = os.path.join(ROOT_DIR, 'CelebAMask-total-mask')
# Time-stamped output directory, so every run gets its own folder.
target_dir = os.path.join(ROOT_DIR, 'facial_segmentation_tfrecords_' + get_time_str())
touch_and_clean_dir(target_dir)

# walkdir() fills this module-level list with the image paths it finds.
example_address_book = []
walkdir(face_dir)
random.shuffle(example_address_book)

farm = ParallelFarm(
    example_address_book,
    output_dir=target_dir,
    parse_func=parse_example,
    worker_num=PARAMS.parallel_num,
    group_percent=PARAMS.percent,
    max_per_file=PARAMS.max_per_file,
)
farm.run()
farm.dump_meta_json()
#Rename those tfrecords
#======================================
import os
import shutil
def rename_tfrecord_with_meta(meta_json_path):
    """Rename the TFRecord files listed in a meta JSON to sharded names.

    For each of the groups 'train', 'val' and 'test', the files listed under
    ``<group>_batch_list`` are renamed, in list order, to
    ``<group>-NNNNN-of-MMMMM.tfrecord``. The files are expected next to the
    meta file. The meta file is then rewritten with the new names.

    Calling the function again on an already renamed set is a no-op apart
    from rewriting the meta file.

    Args:
        meta_json_path: Path of the meta JSON file.

    Raises:
        FileNotFoundError: If ``meta_json_path`` does not exist.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if not os.path.exists(meta_json_path):
        raise FileNotFoundError(str(meta_json_path) + ' not found.')
    base_dir = os.path.dirname(os.path.abspath(meta_json_path))
    with open(meta_json_path, 'r') as f:
        meta = json.load(f)

    for group in ('train', 'val', 'test'):
        group_key = group + '_batch_list'
        # Use the recorded names as the source of truth instead of rebuilding
        # them from the index, so a repeated call finds the files it moved.
        original_names = meta[group_key]
        shard_count = len(original_names)
        new_batch_filenames = []
        for i, original_name in enumerate(original_names):
            new_name = '{}-{:05d}-of-{:05d}.tfrecord'.format(group, i, shard_count)
            if original_name != new_name:
                shutil.move(os.path.join(base_dir, original_name),
                            os.path.join(base_dir, new_name))
            new_batch_filenames.append(new_name)
        meta[group_key] = new_batch_filenames

    # write new meta back
    with open(meta_json_path, 'w') as fout:
        fout.write(json.dumps(meta, indent=4))
# NOTE(review): TEST_DIR is not defined anywhere in the visible source, and the
# farm above writes 'meta.json' rather than 'test_meta.json'. Confirm both the
# directory and the file name before running this line.
rename_tfrecord_with_meta(os.path.join(TEST_DIR, 'test_meta.json'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.