Last active
September 16, 2018 00:45
-
-
Save trcook/e5e3ed6dd0306255cc96b5bc6bc55314 to your computer and use it in GitHub Desktop.
TFRecord writer for image data: files for writing TFRecords and for reading TFRecords back into TensorFlow datasets.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class GWData(object):
    """
    Wrap a TFRecord file (or list of files) as a batched, endlessly repeating ``tf.data.Dataset``.

    The record format must line up with the output produced by tf_record_writer.py:
    each ``tf.train.Example`` holds raw image bytes under ``"image"``, a byte-string
    label under ``"label"`` and the image's first two dimensions (stored as floats)
    under ``"X"`` and ``"Y"``.

    Note that in the *parsed* output, ``'X'`` is the image tensor and ``'Y'`` is the
    label; they are not the stored dimensions.

    For a more easily workable example:
    https://gist.github.com/trcook/9fc8698cf7dc848a953f8e7a7e5f1aad

    :param the_file: path to a TFRecord file, or a list/tuple of paths.
    :param batch_size: number of examples per batch (default 3).
    :param shuffle_buffer: shuffle buffer size (default 1). A buffer of 1 keeps the
        records in file order; raise it to actually shuffle.

    :Example:
    ::

        dataset = GWData('./output.tfrecord')
        val_dataset = GWData('./validate.tfrecord')
        # make a reinitializable iterator that can be hot-swapped in keras
        it = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        # make initializers for the iterator for dataset and val_dataset
        dataset.mk_init(it)
        val_dataset.mk_init(it)
        # get a handle for the next batch from the iterator
        next_item = it.get_next()
        # produce X (image) and Y (label) handles that can be used as targets for keras
        X = next_item['X']
        Y = next_item['Y']

        # Into Keras:
        from keras import backend as K
        import keras as k
        from keras.layers import *
        sess = K.get_session()
        inputs = Input(tensor=X)
        labs = Input(tensor=Y)
        net = Dense(.....)(inputs)
        # ... rest of model goes here
        output = Dense(....)(net)
        model = k.Model(inputs, output)
        # set up loss, etc., anything else you do before normally calling fit:
        model.add_loss(loss)
        model.compile('adam')
        # point the iterator at the training data:
        sess.run(dataset.it_init)
        model.fit(.....)
        # now hot-swap dataset out for val_dataset:
        sess.run(val_dataset.it_init)
        model.evaluate(...)
        # this makes it easy enough to grab the outputs too:
        sess.run(output)
    """

    def __init__(self, the_file: str, batch_size: int = 3, shuffle_buffer: int = 1):
        self.filename = the_file
        self.batch_size = batch_size
        self.shuffle_buffer = shuffle_buffer
        self.dataset = self.mk_data(self.filename,
                                    batch_size=batch_size,
                                    shuffle_buffer=shuffle_buffer)
        # Exposed so callers can build a reinitializable iterator with
        # tf.data.Iterator.from_structure(output_types, output_shapes).
        self.output_types = self.dataset.output_types
        self.output_shapes = self.dataset.output_shapes
        # Set by mk_init(); None until then.
        self.it_init = None

    def parse_label(self, lab: tf.string):
        """
        Hook for transforming the label tensor. NOT IMPLEMENTED YET: returns ``lab`` unchanged.
        """
        # probably put tensor through something like word-to-vec, but that may make
        # more sense *before* we pack into tfrecord.
        return lab

    def parse_function(self, ex: tf.train.Example) -> dict:
        """
        Parse one serialized example back into a usable set of tensors.

        This is mapped over every record in the TFRecord file.

        :param ex: scalar string tensor holding one serialized ``tf.train.Example``.
        :return: ``{'X': image, 'Y': label}``, where ``image`` is a float32 tensor of
            shape ``[X, Y, 3]`` (the stored dimensions) with values in [0, 255], and
            ``label`` is whatever ``parse_label`` returns.
        """
        # Feature map: the keys must match the feature names written into each
        # example, and the dtypes must match how they were written.
        feature_map = {
            "image": tf.FixedLenFeature((), tf.string),
            "label": tf.FixedLenFeature((), tf.string),
            "X": tf.FixedLenFeature((), tf.float32),
            "Y": tf.FixedLenFeature((), tf.float32),
        }
        pex = tf.parse_single_example(ex, feature_map)
        # The dimensions were stored as floats; reshape needs integers.
        X = tf.cast(pex['X'], tf.int32)
        Y = tf.cast(pex['Y'], tf.int32)
        D = tf.constant(3)
        image_dims = tf.stack([X, Y, D])
        # The writer serializes the image with ndarray.tobytes() on arrays loaded by
        # cv2.imread, which are unsigned 8-bit. Decode as tf.uint8: decoding as
        # tf.int8 would wrap pixel values 128-255 around to -128..-1.
        img = tf.decode_raw(pex['image'], tf.uint8)
        # Cast to float32 because float inputs are easier to use in a model.
        img = tf.cast(img, tf.float32)
        # Restore the original (rows, cols, channels) shape.
        img = tf.reshape(img, image_dims)
        label = self.parse_label(pex['label'])
        return {'X': img, 'Y': label}

    def mk_data(self, the_file: "str or list", batch_size: int = 3, shuffle_buffer: int = 1):
        """
        Build the input pipeline: read -> shuffle -> parse -> repeat -> batch.

        :param the_file: a TFRecord path, or a list/tuple of paths.
        :param batch_size: number of examples per batch.
        :param shuffle_buffer: shuffle buffer size; 1 leaves the record order unchanged.
        :return: the batched, endlessly repeating ``tf.data.Dataset``.
        """
        # Accept a single file or a list/tuple of files.
        if isinstance(the_file, (list, tuple)):
            the_file = list(the_file)
        else:
            the_file = [the_file]
        dataset = tf.data.TFRecordDataset(the_file)
        # Shuffle the serialized records *before* parsing. The buffer then holds
        # the serialized uint8 bytes rather than decoded float32 images, which are
        # four times larger.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
        dataset = dataset.map(self.parse_function)
        # Repeat the input indefinitely.
        dataset = dataset.repeat()
        dataset = dataset.batch(batch_size)
        return dataset

    def mk_init(self, it: tf.data.Iterator):
        '''
        Create the op that switches iterator ``it`` to read from this dataset.

        After this runs, the instance's ``it_init`` attribute holds the initializer.
        Running it in a session points ``it`` at this instance's dataset.

        :Example:
        ::

            import tensorflow as tf
            dat = GWData('file.tfrecord')
            it = tf.data.Iterator.from_structure(dat.output_types, dat.output_shapes)
            dat.mk_init(it)
            sess = tf.Session()
            # point iterator 'it' at the dataset in 'dat':
            sess.run(dat.it_init)
            # tell iterator 'it' to get a batch from the dataset in 'dat':
            sess.run(it.get_next())
        '''
        # The initializer is specific to *this* dataset.
        self.it_init = it.make_initializer(self.dataset)
import numpy as np
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python | |
# This file writes records into tensorflow | |
import tensorflow as tf | |
import cv2 | |
import glob | |
import os | |
import argparse | |
from tqdm import tqdm | |
# Command-line options for the writer script: where to write the TFRecord and
# which glob pattern selects the input images.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-o", "--output",
    dest="output",
    default="output.tfrecord",
    help="output name",
)
parser.add_argument(
    "-g", "--glob",
    dest="globber",
    default="./**/*.[Jj][Pp][Gg]",
    help="glob pattern to use for finding images",
)
def parse_path(x):
    """Placeholder, presumably for deriving a label from an image path (TODO confirm).

    :raises NotImplementedError: always; not implemented yet.
    """
    raise NotImplementedError()
def _fit_axis(arr, target, axis):
    """Center-pad with zeros, or center-crop, ``arr`` along ``axis`` to length ``target``."""
    # Imported locally so this helper does not depend on a module-level numpy import.
    import numpy as np

    diff = target - arr.shape[axis]
    if diff > 0:
        # Split the padding evenly; an odd leftover pixel goes at the end.
        pad_width = [(0, 0)] * arr.ndim
        pad_width[axis] = (diff // 2, diff - diff // 2)
        return np.pad(arr, pad_width, 'constant')
    if diff < 0:
        # Keep the central `target` entries along this axis.
        start = (-diff) // 2
        index = [slice(None)] * arr.ndim
        index[axis] = slice(start, start + target)
        return arr[tuple(index)]
    return arr


def get_image(img, sized=(1800, 1200)):
    """
    Load an image file and force it to a fixed size.

    :param img: path of the image file to read with OpenCV.
    :param sized: target ``(rows, cols)``. Each axis is handled independently: it
        is center-padded with zeros if too small, or center-cropped if too large.
    :return: RGB image array of shape ``(sized[0], sized[1], 3)``.
    :raises OSError: if OpenCV cannot read the file.
    """
    x = cv2.imread(img)
    # cv2.imread signals failure by returning None instead of raising.
    if x is None:
        raise OSError("could not read image: {}".format(img))
    x = _fit_axis(x, sized[0], 0)  # rows
    x = _fit_axis(x, sized[1], 1)  # columns
    # OpenCV loads channels in BGR order; convert to RGB.
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    return x
def int_feature(x):
    """Wrap a single integer ``x`` in an int64-list ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[x]))


def byte_feature(x):
    """Wrap a single bytes value ``x`` in a bytes-list ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x]))


def float_feature(x):
    """Wrap a single number ``x`` in a float-list ``tf.train.Feature``."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[x]))
def make_example_from_image(img, path):
    """
    Serialize one image array and its source path into a ``tf.train.Example``.

    Features written (the reading side must parse the same keys and types):
      - ``"image"``: raw bytes of ``img`` (``ndarray.tobytes()``)
      - ``"label"``: ``path`` encoded as UTF-8
      - ``"X"`` / ``"Y"``: ``img.shape[0]`` / ``img.shape[1]``, stored as floats

    ``img`` must be 3-dimensional.
    """
    rows, cols, _channels = img.shape
    feature = {}
    feature["image"] = byte_feature(img.tobytes())
    feature["label"] = byte_feature(bytes(path, 'utf-8'))
    feature["X"] = float_feature(rows)
    feature["Y"] = float_feature(cols)
    return tf.train.Example(features=tf.train.Features(feature=feature))
def record_from_image_paths(files, output_file):
    """
    Write one ``tf.train.Example`` per image path in ``files`` to ``output_file``.

    Each image is loaded and resized with ``get_image`` and then serialized with
    ``make_example_from_image``, which uses the path as the label.
    """
    # The context manager closes (and flushes) the writer even if reading or
    # serializing an image raises partway through.
    with tf.python_io.TFRecordWriter(output_file) as w:
        for path in tqdm(files):
            img = get_image(path)
            ex = make_example_from_image(img, path)
            w.write(ex.SerializeToString())
if __name__ == '__main__':
    args = parser.parse_args()
    output = args.output
    globber = args.globber
    print(globber)
    # recursive=True is needed for '**' (as in the default pattern) to match nested
    # directories. Sorting makes the record order deterministic.
    files = sorted(glob.glob(globber, recursive=True))
    print(files)
    record_from_image_paths(files, output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment