Skip to content

Instantly share code, notes, and snippets.

@bveeramani
Last active August 10, 2022 21:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bveeramani/24d74097821e4be19d1e586ea267302e to your computer and use it in GitHub Desktop.
Save bveeramani/24d74097821e4be19d1e586ea267302e to your computer and use it in GitHub Desktop.
from typing import Dict
import pandas as pd
import pyarrow
import ray
import tensorflow as tf
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
def main():
    """Round-trip a tiny DataFrame through a TFRecords file and read it back with Ray.

    Writes ``foobar.tfrecords`` (a relative path, so into the current working
    directory), reads it back through ``TFRecordsDatasource`` with a matching
    feature spec, and prints the dataset followed by the rows from ``take()``.
    """
    output_path = "foobar.tfrecords"
    frame = pd.DataFrame({"foo": ["cat", "dog", "elephant"], "bar": [0, 1, 2]})
    write_dataframe(frame, output_path)

    # Spec handed to ``tf.io.parse_single_example`` inside the datasource:
    # one scalar string ("foo") and one scalar float32 ("bar") per record.
    schema = dict(
        foo=tf.io.FixedLenFeature([], tf.string),
        bar=tf.io.FixedLenFeature([], tf.float32),
    )
    ds = ray.data.read_datasource(
        TFRecordsDatasource(),
        paths=[output_path],
        features=schema,
    )

    # Print the dataset itself first, then materialize and print its rows.
    print(ds)
    rows = ds.take()
    print(rows)
def write_dataframe(df: pd.DataFrame, path: str) -> None:
    """Write a Pandas DataFrame to a `tfrecords` file.

    Each row becomes one ``tf.train.Example`` with two features:

    * ``"foo"``: the row's ``foo`` value, UTF-8 encoded into a ``BytesList``.
    * ``"bar"``: the row's ``bar`` value, stored in a ``FloatList``.

    Args:
        df: Frame with a string ``foo`` column and a numeric ``bar`` column.
            Any other columns are ignored.
        path: Destination path passed to ``tf.io.TFRecordWriter``.

    Note:
        Examples are written as they are built. If a row fails to serialize,
        the rows before it have already been written to ``path``.
    """
    with tf.io.TFRecordWriter(path=path) as writer:
        # Stream each example straight to the writer instead of first collecting
        # every Example in a list. Iterate the two columns directly rather than
        # using ``iterrows``, which builds a Series object for every row.
        for foo, bar in zip(df["foo"], df["bar"]):
            features = tf.train.Features(feature={
                "foo": tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(foo, "utf-8")])),
                "bar": tf.train.Feature(float_list=tf.train.FloatList(value=[bar])),
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
class TFRecordsDatasource(FileBasedDatasource):
    """Ray ``FileBasedDatasource`` that reads ``tfrecords`` files into pandas blocks.

    Each record is parsed with ``tf.io.parse_single_example`` using the
    ``features`` spec supplied as a reader argument. The ``foo`` and ``bar``
    features are returned as a two-column DataFrame.
    """

    _FILE_EXTENSION = "tfrecords"

    def _read_file(
        self, f: "pyarrow.NativeFile", path: str, features: Dict[str, tf.io.FixedLenFeature], **reader_args
    ) -> Block:
        """Parse one TFRecords file into a DataFrame block.

        Args:
            f: File handle supplied by Ray. Not used; TensorFlow opens ``path`` itself.
            path: Path of the file, handed to ``tf.data.TFRecordDataset``.
            features: Feature spec for ``tf.io.parse_single_example``. It must
                yield a bytes/string ``"foo"`` feature and a numeric ``"bar"`` feature.
            **reader_args: Any further reader arguments; ignored.

        Returns:
            A DataFrame with one row per record: ``foo`` holds UTF-8-decoded
            strings and ``bar`` holds floats.
        """
        # NOTE(review): ``f`` is ignored and TensorFlow re-opens ``path`` directly.
        # This presumably only works for paths TF can open on its own; confirm
        # the behavior for remote filesystems.
        dataset = tf.data.TFRecordDataset([path])
        dataset = dataset.map(lambda serialized: tf.io.parse_single_example(serialized, features))

        # Build both columns in a single pass. Iterating the dataset once per
        # column re-runs the input pipeline, re-reading and re-parsing the file.
        foo, bar = [], []
        for record in dataset:
            foo.append(record["foo"].numpy().decode("utf-8"))
            bar.append(float(record["bar"]))
        return pd.DataFrame({"foo": foo, "bar": bar})
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment