Skip to content

Instantly share code, notes, and snippets.

@BryanCutler
Last active February 25, 2020 18:24
TensorFlow Arrow Blog Part 7 - Model Training Local Dataset
def make_local_dataset(filename):
    """Build a TensorFlow Arrow dataset backed by a local CSV file.

    The file is read through ``read_and_process``; the resulting record
    batches are streamed into an ``ArrowStreamDataset`` declared as three
    scalar columns (int64, float64, float64). The two float64 columns are
    then stacked into a single feature tensor.

    Args:
        filename: Local CSV file, passed unchanged to ``read_and_process``.

    Returns:
        The mapped dataset. Each element is a pair: the two float64 columns
        stacked along axis 1, followed by the int64 column.
    """
    # First record batch iterator over the local file.
    record_batches = read_and_process(filename)

    # Declared schema of the stream: one int64 column then two float64
    # columns, each a scalar per row.
    column_types = (tf.int64, tf.float64, tf.float64)
    column_shapes = (
        tf.TensorShape([]),
        tf.TensorShape([]),
        tf.TensorShape([]),
    )

    # Factory that re-runs read_and_process on the same file.
    # NOTE(review): presumably used by the dataset to obtain a fresh iterator
    # on later passes over the data — confirm against the tensorflow-io docs.
    make_batch_iter = partial(read_and_process, filename)

    dataset = arrow_io.ArrowStreamDataset.from_record_batches(
        record_batches,
        output_types=column_types,
        output_shapes=column_shapes,
        batch_mode='auto',
        record_batch_iter_factory=make_batch_iter,
    )

    def _combine_features(label, x0, x1):
        # Merge the two feature columns into one tensor and move it ahead of
        # the first column.
        # NOTE(review): the first column is presumably the training label —
        # confirm against read_and_process.
        features = tf.stack([x0, x1], axis=1)
        return features, label

    return dataset.map(_combine_features)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment