-
-
Save markemus/74ba47d0b58f91d7aa7885341ed3b1b8 to your computer and use it in GitHub Desktop.
"""Easily save tf.data.Datasets as tfrecord files, and restore tfrecords as Datasets. | |
The goal of this module is to create a SIMPLE api to tfrecords that can be used without | |
learning all of the underlying mechanics. | |
Users only need to deal with 2 functions: | |
save(dataset) | |
dataset = load(tfrecord, header) | |
To make this work, we create a .header file for each tfrecord which encodes metadata | |
needed to reconstruct the original dataset. | |
Saving must be done in eager mode, but loading is compatible with both eager and | |
graph execution modes. | |
GOTCHAS: | |
- This module is only compatible with "dictionary-style" datasets {key: val, key2:val2,..., keyN: valN}. | |
- The restored dataset will have the TFRecord dtypes {float32, int64, string} instead of the original | |
tensor dtypes. This is always the case with TFRecord datasets, whether you use this module or not. | |
The original dtypes are stored in the headers if you want to restore them after loading.""" | |
import functools | |
import os | |
import tempfile | |
import numpy as np | |
import yaml | |
import tensorflow as tf | |
# The three encoding functions. | |
def _bytes_feature(value): | |
"""value: list""" | |
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) | |
def _float_feature(value): | |
"""value: list""" | |
return tf.train.Feature(float_list=tf.train.FloatList(value=value)) | |
def _int64_feature(value): | |
"""value: list""" | |
return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) | |
#TODO use base_type() to ensure consistent conversion. | |
def np_value_to_feature(value): | |
"""Maps dataset values to tf Features. | |
Only numpy types are supported since Datasets only contain tensors. | |
Each datatype should only have one way of being serialized.""" | |
if isinstance(value, np.ndarray): | |
# feature = _bytes_feature(value.tostring()) | |
if np.issubdtype(value.dtype, np.integer): | |
feature = _int64_feature(value.flatten()) | |
elif np.issubdtype(value.dtype, np.float): | |
feature = _float_feature(value.flatten()) | |
elif np.issubdtype(value.dtype, np.bool): | |
feature = _int64_feature(value.flatten()) | |
else: | |
raise TypeError(f"value dtype: {value.dtype} is not recognized.") | |
elif isinstance(value, bytes): | |
feature = _bytes_feature([value]) | |
elif np.issubdtype(type(value), np.integer): | |
feature = _int64_feature([value]) | |
elif np.issubdtype(type(value), np.float): | |
feature = _float_feature([value]) | |
else: | |
raise TypeError(f"value type: {type(value)} is not recognized. value must be a valid Numpy object.") | |
return feature | |
def base_type(dtype): | |
"""Returns the TFRecords allowed type corresponding to dtype.""" | |
int_types = [tf.int8, tf.int16, tf.int32, tf.int64, | |
tf.uint8, tf.uint16, tf.uint32, tf.uint64, | |
tf.qint8, tf.qint16, tf.qint32, | |
tf.bool] | |
float_types = [tf.float16, tf.float32, tf.float64] | |
byte_types = [tf.string, bytes] | |
if dtype in int_types: | |
new_dtype = tf.int64 | |
elif dtype in float_types: | |
new_dtype = tf.float32 | |
elif dtype in byte_types: | |
new_dtype = tf.string | |
else: | |
raise ValueError(f"dtype {dtype} is not a recognized/supported type!") | |
return new_dtype | |
def build_header(dataset): | |
"""Build header dictionary of metadata for the tensors in the dataset. This will be used when loading | |
the tfrecords file to reconstruct the original tensors from the raw data. Shape is stored as an array | |
and dtype is stored as an enumerated value (defined by tensorflow).""" | |
header = {} | |
for key in dataset.element_spec.keys(): | |
header[key] = {"shape": list(dataset.element_spec[key].shape), "dtype": dataset.element_spec[key].dtype.as_datatype_enum} | |
return header | |
def build_feature_desc(header): | |
"""Build feature_desc dictionary for the tensors in the dataset. This will be used to reconstruct Examples | |
from the tfrecords file. | |
Assumes FixedLenFeatures. | |
If you got VarLenFeatures I feel bad for you son, | |
I got 115 problems but a VarLenFeature ain't one.""" | |
feature_desc = {} | |
for key, params in header.items(): | |
feature_desc[key] = tf.io.FixedLenFeature(shape=params["shape"], dtype=base_type(int(params["dtype"]))) | |
return feature_desc | |
def dataset_to_examples(ds): | |
"""Converts a dataset to a dataset of tf.train.Example strings. Each Example is a single observation. | |
WARNING: Only compatible with "dictionary-style" datasets {key: val, key2:val2,..., keyN, valN}. | |
WARNING: Must run in eager mode!""" | |
# TODO handle tuples and flat datasets as well. | |
for x in ds: | |
# Each individual tensor is converted to a known serializable type. | |
features = {key: np_value_to_feature(value.numpy()) for key, value in x.items()} | |
# All features are then packaged into a single Example object. | |
example = tf.train.Example(features=tf.train.Features(feature=features)) | |
yield example.SerializeToString() | |
def save(dataset, tfrecord_path, header_path): | |
"""Saves a flat dataset as a tfrecord file, and builds a header file for reloading as dataset.""" | |
# Header | |
header = build_header(dataset) | |
header_file = open(header_path, "w") | |
yaml.dump(header, stream=header_file) | |
# Dataset | |
ds_examples = tf.data.Dataset.from_generator(lambda: dataset_to_examples(dataset), output_types=tf.string) | |
writer = tf.data.experimental.TFRecordWriter(tfrecord_path) | |
writer.write(ds_examples) | |
# TODO-DECIDE is this yaml loader safe? | |
def load(tfrecord_path, header_path): | |
"""Uses header file to predict the shape and dtypes of tensors for tf.data.""" | |
header_file = open(header_path) | |
header = yaml.load(header_file, Loader=yaml.FullLoader) | |
feature_desc = build_feature_desc(header) | |
parse_func = functools.partial(tf.io.parse_single_example, features=feature_desc) | |
dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_func) | |
return dataset | |
def test(): | |
"""Test super serial saving and loading. | |
NOTE- test will only work in eager mode due to list() dataset cast.""" | |
savefolder = tempfile.TemporaryDirectory() | |
savepath = os.path.join(savefolder.name, "temp_dataset") | |
tfrecord_path = savepath + ".tfrecord" | |
header_path = savepath + ".header" | |
# Data | |
x = np.linspace(1, 3000, num=3000).reshape(10, 10, 10, 3) | |
y = np.linspace(1, 10, num=10).astype(int) | |
ds = tf.data.Dataset.from_tensor_slices({"image": x, "label": y}) | |
# Run | |
save(ds, tfrecord_path=tfrecord_path, header_path=header_path) | |
new_ds = load(tfrecord_path=tfrecord_path, header_path=header_path) | |
# Test that values were saved and restored | |
assert list(ds)[0]["image"].numpy()[0, 0, 0] == list(new_ds)[0]["image"].numpy()[0, 0, 0] | |
assert list(ds)[0]["label"] != list(new_ds)[0]["label"] | |
# Clean up- folder will disappear on crash as well. | |
savefolder.cleanup() | |
if __name__ == "__main__": | |
test() | |
print("Test passed.") |
As I understand save requires to load all examples into memory. Is there a way to iterate over the dataset and each example as a single tfrecord?
Currently that's true, but I don't think it has to be that way. dataset_to_examples()
can probably be converted to a generator and that should be enough.
Could this be extended to that we save batched examples?
The way I use it is to call super_serial.load().batch(32)
. I am not sure whether TFRecords support a batch dimension, but they probably do- if they do, this can be extended to support it.
BTW this is for TF1.x- if you are using TF2.0 you have to make a few minor changes to some of the import paths, and in build_header()
to use element_spec instead of output_shape. I'll push the changes when I get a chance but you can also debug them yourself pretty easily.
@markemus thanks for your reply. Will look into this.
Also interesting: tensorflow/community#193 which is probably address all the hassles
Ah very cool! I didn't know that was coming, that's excellent news. Frankly I was pretty annoyed that I had to write this at all.
@faroit update for tf2.0 is pushed now.
Super_serial 2.0. Headers are not backwards compatible (but easily convertible).
Changes:
- serialization now uses a generator instead of precompiling the full dataset, so memory usage is now very minimal.
- headers are now human readable and do not serialize Python objects.
- added a test.
save
requires to load all examples into memory. Is there a way to iterate over the dataset and each example as a single tfrecord?