Last active
May 23, 2019 01:51
-
-
Save usmcamp0811/8af48beb3f383165c06b8ee71588d498 to your computer and use it in GitHub Desktop.
Takes many CSV files and converts them to a single TFRecord file, then opens that TFRecord file and outputs training data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import pandas as pd | |
import numpy as np | |
from tqdm import tqdm | |
import os | |
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature holding an int64 list of length 1."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _float_feature(value):
    """Wrap a sequence of floats in a tf.train.Feature holding a float list.

    NOTE: unlike the int64/bytes helpers, ``value`` is passed through as-is,
    so callers supply an iterable (e.g. one row of a table), not a scalar.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a bytes list of length 1."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def make_q_list(filepathlist, filetype):
    """Collect matching data-file paths and their labels from directories.

    Each directory's basename is used as the label for every matching file
    it contains (non-recursive scan).

    Args:
        filepathlist: iterable of directory paths to scan.
        filetype: filename suffix to match, e.g. '.csv'.

    Returns:
        A (filepaths, labels) pair of parallel lists: full file paths and
        the label derived from each file's containing directory.
    """
    filepaths = []
    labels = []
    for path in filepathlist:
        # The label is a property of the directory, so compute it once per
        # directory instead of once per file (original recomputed it in the
        # inner loop and also carried redundant `x = x` self-assignments).
        label = os.path.basename(os.path.normpath(path))
        for entry in os.listdir(path):
            if entry.endswith(filetype):
                filepaths.append(os.path.join(path, entry))
                labels.append(label)
    return filepaths, labels
def tables_to_TF(queue_list, tf_filename, file_type='csv'):
    """Serialize rows from many tabular files into one TFRecord file.

    Each row of each input table becomes one tf.train.Example carrying a
    single float feature named 'x'.  The target variable is expected to be
    the last column of the data.

    Args:
        queue_list: list of input file paths.
        tf_filename: output TFRecord file path.
        file_type: 'csv' or 'hdf'; any other value prints a message and
            stops processing (best-effort, matching the original behavior).
    """
    print('Writing', tf_filename)
    writer = tf.python_io.TFRecordWriter(tf_filename)
    try:
        for file in tqdm(queue_list):
            if file_type == 'csv':
                data = pd.read_csv(file).values
            elif file_type == 'hdf':
                data = pd.read_hdf(file).values
            else:
                print(file_type, 'is not supported at this time...')
                break
            for row in data:
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        'x': _float_feature(row)
                    })
                )
                writer.write(example.SerializeToString())
    finally:
        # Bug fix: the writer was never closed, so trailing records could be
        # lost in an unflushed buffer and the file handle leaked.
        writer.close()
def load_tfrecord(file_name, cols_count):
    """Read every record of a TFRecord file into a single stacked tensor.

    Args:
        file_name: path to the TFRecord file.
        cols_count: number of float values stored under feature 'x' in each
            record (must match what tables_to_TF wrote).

    Returns:
        A tensor of shape (num_records, cols_count) with all examples
        concatenated along axis 0.
    """
    features = {'x': tf.FixedLenFeature([cols_count], tf.float32)}
    data = []
    for s_example in tf.python_io.tf_record_iterator(file_name):
        example = tf.parse_single_example(s_example, features=features)
        data.append(tf.expand_dims(example['x'], 0))
    # Bug fix: since TF 1.0 the signature is tf.concat(values, axis); the
    # old (axis, values) order used here raises an error on the TF 1.x APIs
    # this file otherwise relies on (e.g. tf.global_variables_initializer).
    return tf.concat(data, 0)
if __name__ == "__main__": | |
for i in range(10): | |
filename = '/home/mcamp/PycharmProjects/Data/random_csv' + str(i) + '.csv' | |
pd.DataFrame(np.arange(0.0, 5000.0).reshape((100, 50))).to_csv(filename) | |
filepathlist = ['/home/mcamp/PycharmProjects/Data/'] | |
q, _ = make_q_list(filepathlist, '.csv') | |
tffilename = 'Demo_TFR.tfrecords' | |
tables_to_TF(q, tffilename, file_type='csv') | |
data = load_tfrecord(tffilename, 51) | |
example_batch = tf.train.shuffle_batch( | |
[data], batch_size=10, num_threads=1, | |
capacity=2000, enqueue_many=True, | |
# Ensures a minimum amount of shuffling of examples. | |
min_after_dequeue=2, name='Demo_TFR.tfrecords') | |
with tf.Session() as sess: | |
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) | |
sess.run(init_op) | |
coord = tf.train.Coordinator() | |
threads = tf.train.start_queue_runners(coord=coord) | |
#TODO: Add train shuffle batch | |
df = sess.run([example_batch]) | |
print(df) | |
df = pd.DataFrame(df[0]) | |
y = df[df.columns[-1]] | |
X = df[df.columns[:-1]] | |
print(X,y) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment