Last active
April 21, 2019 14:03
-
-
Save jinyu121/a492d272d5293cff544b7578cd52aa27 to your computer and use it in GitHub Desktop.
TensorFlow Slim Classification Data Convert and Read
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import tensorflow as tf | |
import tensorflow.contrib.slim as slim | |
def get_dataset(dataset_dir, num_classes, labels_to_names_path=None):
    """Build a ``slim.dataset.Dataset`` for image classification records.

    Args:
        dataset_dir: Value handed both to ``tf.python_io.tf_record_iterator``
            (to count the records) and to the dataset as ``data_sources``.
        num_classes: Number of classes; stored on the dataset and used in the
            textual description of the ``label`` item.
        labels_to_names_path: Optional path of a text file with one class name
            per line. When given, line ``i`` (0-based) becomes the name for
            label ``i``.

    Returns:
        A ``slim.dataset.Dataset`` that decodes ``image`` and ``label`` items.
    """
    # Count the records by walking the TFRecord data once.
    record_count = 0
    for _ in tf.python_io.tf_record_iterator(dataset_dir):
        record_count += 1

    # How the serialized tf.Example fields are parsed.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    # How parsed fields are exposed as dataset items.
    handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    descriptions = {
        'image': 'A color image of varying size.',
        'label': 'A single integer between 0 and ' + str(num_classes - 1),
    }

    # Optional label -> name table, keyed by the line number in the file.
    if labels_to_names_path is None:
        names_by_label = None
    else:
        with open(labels_to_names_path, 'r') as names_file:
            names_by_label = dict(
                enumerate(raw.strip() for raw in names_file))

    return slim.dataset.Dataset(
        data_sources=dataset_dir,
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=record_count,
        items_to_descriptions=descriptions,
        num_classes=num_classes,
        labels_to_names=names_by_label)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import random | |
import tensorflow as tf | |
from tqdm import tqdm | |
# Command-line flags of the image-folder to TFRecord conversion script.
# The help strings are runtime text and are kept as written; the trailing
# comments give their English meaning.
flags = tf.app.flags
flags.DEFINE_string('dir', '', '图片路径')  # image directory path
flags.DEFINE_string('output', '', '输出路径')  # output path
flags.DEFINE_integer('channels', 3, '图片通道数量')  # number of image channels
flags.DEFINE_float('train_scale', 0.8, '训练集比例')  # training-set ratio
flags.DEFINE_bool('reshape', False, '缩放')  # resize
flags.DEFINE_integer('height', 224, '缩放高度')  # resize height
flags.DEFINE_integer('width', 224, '缩放宽度')  # resize width
FLAGS = flags.FLAGS
def int64_feature(value):
    """Wrap an integer or a sequence of integers into an int64 ``Feature``.

    Args:
        value: A single integer, or a list/tuple of integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    # Tuples are sequences too; wrapping one in a list would nest it and be
    # rejected by the Int64List proto.
    if isinstance(value, tuple):
        value = list(value)
    elif not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value):
    """Wrap a float or a sequence of floats into a float ``Feature``.

    Args:
        value: A single number, or a list/tuple of numbers.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the value(s).
    """
    # Tuples are sequences too; wrapping one in a list would nest it and be
    # rejected by the FloatList proto.
    if isinstance(value, tuple):
        value = list(value)
    elif not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value):
    """Wrap a bytes object or a sequence of bytes objects into a ``Feature``.

    Args:
        value: A single bytes value, or a list/tuple of bytes values.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).
    """
    # Tuples are sequences too; wrapping one in a list would nest it and be
    # rejected by the BytesList proto.
    if isinstance(value, tuple):
        value = list(value)
    elif not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Pack one image and its metadata into a ``tf.train.Example``.

    Args:
        image_data: Encoded image bytes, stored under ``image/encoded``.
        image_format: Image format as bytes, stored under ``image/format``.
        height: Image height, stored under ``image/height``.
        width: Image width, stored under ``image/width``.
        class_id: Class label, stored under ``image/class/label``.

    Returns:
        The assembled ``tf.train.Example``.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }
    example_features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=example_features)
def convert_to_tfrecord(save_file_name, txt_file_name, files, choices):
    """Write the selected samples to a TFRecord file and a text listing.

    Each selected image is decoded once (to obtain its height and width) and
    its raw file bytes are stored in the record together with its label.

    Args:
        save_file_name: Path of the TFRecord file to create.
        txt_file_name: Path of the text file that receives one
            ``"<path> <label>"`` line per written sample.
        files: Sequence of ``(image_path, label)`` pairs.
        choices: Indices into ``files`` selecting the samples to write, in
            the order they are written.
    """
    selected = [files[i] for i in choices]

    # Small decoding graph: encoded bytes in, image array out.
    data_input = tf.placeholder(dtype=tf.string)
    process = tf.image.decode_jpeg(data_input, channels=FLAGS.channels)
    if FLAGS.reshape:
        # NOTE(review): only the recorded height/width reflect this resize;
        # the bytes written to the record are the original file contents.
        # Confirm whether the resized image should be re-encoded instead.
        process = tf.image.resize_images(process, [FLAGS.height, FLAGS.width])

    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        with tf.python_io.TFRecordWriter(save_file_name) as writer:
            with open(txt_file_name, 'w') as txt:
                for item in tqdm(selected):
                    image_path, label = item[0], item[1]
                    txt.write("{} {}\n".format(image_path, label))
                    # Close every image file promptly; the previous code left
                    # one open handle behind per image.
                    with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                        image_data = image_file.read()
                    image = sess.run(
                        process, feed_dict={data_input: image_data})
                    h, w = image.shape[0:2]
                    example = image_to_tfexample(
                        image_data, b'jpg', h, w, label)
                    writer.write(example.SerializeToString())
def get_folders(path):
    """Return the sorted names of the directories directly inside ``path``.

    Only the first entry produced by ``os.walk`` (the top level) is used, so
    nested directories are not listed.

    Args:
        path: Directory to inspect.

    Returns:
        A sorted list of directory names. When ``os.walk`` yields nothing
        (for example because ``path`` does not exist) an empty list is
        returned rather than an implicit ``None``.
    """
    _, dirnames, _ = next(os.walk(path), (path, [], []))
    return sorted(dirnames)
def get_files(path):
    """Return the sorted names of the files directly inside ``path``.

    Only the first entry produced by ``os.walk`` (the top level) is used, so
    files in sub-directories are not listed.

    Args:
        path: Directory to inspect.

    Returns:
        A sorted list of file names. When ``os.walk`` yields nothing (for
        example because ``path`` does not exist) an empty list is returned
        rather than an implicit ``None``.
    """
    _, _, filenames = next(os.walk(path), (path, [], []))
    return sorted(filenames)
def get_descrip(path, label):
    """Yield ``(path[i], label[i])`` pairs in a random order.

    The number of pairs is taken from ``len(label)``; every index is visited
    exactly once, in an order drawn with ``random.sample``.

    Args:
        path: Indexable collection of paths.
        label: Indexable collection of labels, parallel to ``path``.

    Yields:
        Tuples ``(path[i], label[i])``.
    """
    count = len(label)
    shuffled_indices = random.sample(range(0, count), count)
    for index in shuffled_indices:
        pair = (path[index], label[index])
        yield pair
def get_classes(path):
    """Map each folder name found by ``get_folders(path)`` to its position.

    Args:
        path: Directory whose folders name the classes.

    Returns:
        A dict ``{folder_name: index}`` where ``index`` is the 0-based
        position of the folder in the list returned by ``get_folders``.
    """
    return {name: index for index, name in enumerate(get_folders(path))}
def main(args):
    """Convert a folder-per-class image directory into train/test TFRecords.

    Reads the images under ``FLAGS.dir`` (one sub-folder per class) and writes
    into ``FLAGS.output``: ``label.txt`` (class names), ``list.txt`` (all
    samples), ``list_train.txt`` / ``list_test.txt`` (the split listings) and
    ``train.tfrecord`` / ``test.tfrecord``.

    Args:
        args: Argument supplied by the caller (``tf.app.run`` at module
            level); unused.
    """
    files = []

    # Output path definitions.
    description_file_name = os.path.join(FLAGS.output, 'list.txt')
    label_file_name = os.path.join(FLAGS.output, 'label.txt')
    train_file_name = os.path.join(FLAGS.output, 'list_train.txt')
    test_file_name = os.path.join(FLAGS.output, 'list_test.txt')
    tf_train_file_name = os.path.join(FLAGS.output, 'train.tfrecord')
    tf_test_file_name = os.path.join(FLAGS.output, 'test.tfrecord')

    # Use the folder names as class names. Iterate in class-index order so
    # the outputs do not depend on the dict's iteration order.
    classes = get_classes(FLAGS.dir)
    ordered_classes = sorted(classes.items(), key=lambda entry: entry[1])

    # Write the label file: one class name per line, in class-index order.
    with open(label_file_name, 'w') as ofile:
        for class_name, _ in ordered_classes:
            ofile.write("{}\n".format(class_name))

    # Collect every image path together with its label.
    # The stored label is the class index plus one (labels start at 1).
    for class_name, class_index in ordered_classes:
        class_dir = os.path.join(FLAGS.dir, class_name)
        for file_name in get_files(class_dir):
            files.append((os.path.join(class_dir, file_name), class_index + 1))

    # Write the listing of all samples.
    with open(description_file_name, 'w') as ofile:
        for item in files:
            ofile.write("{} {}\n".format(item[0], item[1]))

    # Count the samples and draw a random train/test split.
    num_samples = len(files)
    num_train = int(num_samples * FLAGS.train_scale)
    choice = random.sample(range(0, num_samples), num_samples)

    # Run the TensorFlow conversion for both splits.
    convert_to_tfrecord(tf_train_file_name, train_file_name, files, choice[0:num_train])
    convert_to_tfrecord(tf_test_file_name, test_file_name, files, choice[num_train:])
# Script entry point: let TensorFlow parse the flags and run the script.
if __name__ == "__main__":
    tf.app.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.contrib.data import Dataset | |
from tensorflow.python.framework import dtypes | |
from tensorflow.python.framework.ops import convert_to_tensor | |
class ImageDataGenerator(object):
    """Batched image/label input pipeline built from a text listing.

    The listing file holds one sample per line in the form
    ``<image path> <label>``. After construction, ``self.data`` is the
    batched dataset and ``self.data_size`` the number of listed samples.
    """

    def __init__(self, txt_file, batch_size, num_classes, shuffle=True,
                 image_size=None, one_hot=True):
        """Read the listing and assemble the dataset.

        Args:
            txt_file: Path of the listing file.
            batch_size: Number of samples per batch; the map output buffer
                and the shuffle buffer hold ``100 * batch_size`` elements.
            num_classes: Depth of the one-hot label encoding.
            shuffle: Whether to shuffle the dataset through a buffer.
            image_size: Optional size passed to ``tf.image.resize_images``;
                ``None`` leaves decoded images at their own size.
            one_hot: Whether labels are one-hot encoded.
        """
        self.image_size = image_size
        self.one_hot = one_hot
        self.txt_file = txt_file
        self.num_classes = num_classes
        buffer_size = 100 * batch_size

        # Retrieve the data from the text file.
        self._read_txt_file()

        # Number of samples in the dataset.
        self.data_size = len(self.labels)

        # NOTE: shuffling the whole list up front with `_shuffle_lists()` is
        # left disabled here, so only the buffered shuffle below is applied.

        # Convert the lists to TF tensors.
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
        self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)

        # Create the dataset: repeat, then decode each sample.
        data = Dataset.from_tensor_slices((self.img_paths, self.labels))
        data = data.repeat()
        data = data.map(
            self._parse_function,
            num_threads=8,
            output_buffer_size=buffer_size
        )

        # Shuffle through a buffer of `buffer_size` elements.
        if shuffle:
            data = data.shuffle(buffer_size=buffer_size)

        # Create a new dataset with batches of images.
        self.data = data.batch(batch_size)

    def _read_txt_file(self):
        """Read the listing file into ``self.img_paths`` and ``self.labels``.

        Blank lines are skipped. Each remaining line is split at its last
        run of whitespace, so image paths that contain spaces are kept whole.

        Raises:
            ValueError: If a non-blank line has no label field or the label
                is not an integer.
        """
        self.img_paths = []
        self.labels = []
        # `with` guarantees the listing file is closed after reading.
        with open(self.txt_file, 'r') as listing:
            for raw_line in listing:
                line = raw_line.strip()
                if not line:
                    continue
                img_path, label = line.rsplit(None, 1)
                self.img_paths.append(img_path)
                self.labels.append(int(label))

    def _shuffle_lists(self):
        """Shuffle the path list and the label list with one permutation."""
        permutation = np.random.permutation(self.data_size)
        self.img_paths = [self.img_paths[x] for x in permutation]
        self.labels = [self.labels[x] for x in permutation]

    def _parse_function(self, filename, label):
        """Load one sample: decode the image and encode the label.

        Args:
            filename: Path of the image file (string tensor).
            label: Integer label of the sample.

        Returns:
            A tuple ``(image, label)``; the label is one-hot encoded with
            depth ``self.num_classes`` when ``self.one_hot`` is set, and
            passed through unchanged otherwise.
        """
        # Convert the label number into a one-hot encoding if requested.
        # NOTE(review): `tf.one_hot` yields an all-zero row for labels outside
        # [0, num_classes); confirm the label range against the listing.
        if self.one_hot:
            lab = tf.one_hot(label, self.num_classes)
        else:
            lab = label

        # Load and preprocess the image (always decoded to 3 channels).
        pipe = tf.read_file(filename)
        pipe = tf.image.decode_jpeg(pipe, channels=3)
        if self.image_size is not None:
            pipe = tf.image.resize_images(pipe, self.image_size)

        return pipe, lab
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment