import os
import uuid

import numpy as np
import tensorflow as tf

__all__ = ['to_tfrecords']

def to_tfrecords(dataframe, out_dir):
    """Convert a pandas DataFrame to GZIP-compressed TFRecord part files in `out_dir`."""
    schema = get_schema(dataframe)
    tfrecords = get_tfrecords(dataframe, schema)
    chunks = split_by_size(tfrecords)
    write_tfrecords(chunks, out_dir)

def bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # unwrap an eager tensor to its raw value
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def write_tfrecords(tfrecords, out_dir):
    """Write each chunk of records to its own GZIP-compressed part file in `out_dir`."""
    uid = str(uuid.uuid4())
    options = tf.io.TFRecordOptions(
        compression_type='GZIP',
        compression_level=9,
    )
    for idx, chunk in enumerate(tfrecords):
        file_path = os.path.join(out_dir, f'part-{str(idx).zfill(5)}-{uid}.tfrecords')
        with tf.io.TFRecordWriter(file_path, options=options) as writer:
            for item in chunk:
                writer.write(item.SerializeToString())

def get_tfrecords(dataframe, schema):
    """Yield a tf.train.Example (or SequenceExample for list columns) per DataFrame row."""
    for _, row in dataframe.iterrows():
        features = {}
        feature_lists = {}
        for col, val in row.items():
            f = schema[col](val)
            if isinstance(f, tf.train.FeatureList):
                feature_lists[col] = f
            elif isinstance(f, tf.train.Feature):
                features[col] = f
        context = tf.train.Features(feature=features)
        if feature_lists:
            # List-valued columns go into a SequenceExample's feature_lists,
            # scalar columns into its context.
            ex = tf.train.SequenceExample(
                context=context,
                feature_lists=tf.train.FeatureLists(feature_list=feature_lists))
        else:
            ex = tf.train.Example(features=context)
        yield ex

def get_feature_func(_type):
    """Return the feature factory matching a python / numpy scalar type."""
    if issubclass(_type, (str, bytes)):
        return bytes_feature
    if issubclass(_type, (bool, int, np.integer)):
        return int64_feature
    if issubclass(_type, (float, np.floating)):
        return float_feature
    raise TypeError(f'Unsupported type {_type!r}')

def get_schema(dataframe):
    """Infer a column -> feature-builder mapping from the types in the first row."""
    schema = {}
    row = dataframe.iloc[0]
    for col in dataframe.columns:
        col_type = type(row[col])
        if col_type in (list, np.ndarray):
            # List-like columns: infer the item type from the first element and
            # build a FeatureList of per-item features.
            item_func = get_feature_func(type(row[col][0]))
            schema[col] = lambda xs, f=item_func: tf.train.FeatureList(
                feature=[f(i) for i in xs])
        else:
            schema[col] = get_feature_func(col_type)
    return schema

def split_by_size(tfrecords, max_mb=50):
    """Group records into chunks of roughly `max_mb` megabytes each."""
    max_size = max_mb * 1024 * 1024
    cur_size = 0
    chunk = []
    for row in tfrecords:
        if chunk and cur_size + row.ByteSize() > max_size:
            yield chunk
            chunk = []
            cur_size = 0
        chunk.append(row)
        cur_size += row.ByteSize()
    if chunk:
        yield chunk
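

# Example usage: a minimal sketch assuming a small pandas DataFrame with scalar
# and list-valued columns; writes GZIP-compressed part files into the current
# directory. The DataFrame contents here are purely illustrative.
if __name__ == '__main__':
    import pandas as pd

    df = pd.DataFrame({
        'id': [1, 2, 3],
        'label': ['cat', 'dog', 'bird'],
        'score': [0.5, 1.25, 2.0],
        'embedding': [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    })
    to_tfrecords(df, '.')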