TensorFlow Keras Model Training Example with Apache Arrow Dataset
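This script generates a small labeled CSV dataset, optionally serves it over a TCP socket as an Arrow record batch stream, and trains a Keras logistic regression model on a tensorflow-io ArrowStreamDataset. It targets TensorFlow 1.x (eager execution is enabled explicitly) with a tensorflow-io release that provides tensorflow_io.arrow; run it as-is for the local-file path, or with --run-remote to read the data through the socket server.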
from functools import partial
import multiprocessing
import os
import socket
import sys
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.csv
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow_io.arrow as arrow_io
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def write_csv(filename, num_records):
    """Generate sample data and write to a CSV file."""
    data = {'label': np.random.binomial(1, 0.5, num_records)}
    data['x0'] = np.random.randn(num_records) + 5 * data['label']
    data['x1'] = np.random.randn(num_records) + 5 * data['label']
    df = pd.DataFrame(data)
    df.to_csv(filename, index=False)

def read_and_process(filename):
    """Read the given CSV file and yield processed Arrow batches."""

    # Read a CSV file into an Arrow Table with threading enabled and
    # set block_size in bytes to break the file into chunks for granularity,
    # which determines the number of batches in the resulting pyarrow.Table
    opts = pyarrow.csv.ReadOptions(use_threads=True, block_size=4096)
    table = pyarrow.csv.read_csv(filename, opts)

    # Fit the feature transform
    df = table.to_pandas()
    scaler = StandardScaler().fit(df[['x0', 'x1']])

    # Iterate over batches in the pyarrow.Table and apply processing
    for batch in table.to_batches():
        df = batch.to_pandas()

        # Process the batch and apply feature transform
        X_scaled = scaler.transform(df[['x0', 'x1']])
        df_scaled = pd.DataFrame({'label': df['label'],
                                  'x0': X_scaled[:, 0],
                                  'x1': X_scaled[:, 1]})
        batch_scaled = pa.RecordBatch.from_pandas(df_scaled, preserve_index=False)

        yield batch_scaled

def read_and_process_dir(directory):
    """Read a directory of CSV files and yield processed Arrow batches."""
    for f in os.listdir(directory):
        if f.endswith(".csv"):
            filename = os.path.join(directory, f)
            for batch in read_and_process(filename):
                yield batch

def serve_csv_data(ip_addr, port_num, directory):
    """
    Create a socket and serve Arrow record batches as a stream read from the
    given directory containing CSV files.
    """
    # Create the socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((ip_addr, port_num))
    sock.listen(1)

    # Serve forever, each client will get one iteration over data
    while True:
        conn, _ = sock.accept()
        outfile = conn.makefile(mode='wb')
        writer = None
        try:
            # Read directory and iterate over each batch in each file
            batch_iter = read_and_process_dir(directory)
            for batch in batch_iter:
                # Initialize the pyarrow writer on first batch
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(outfile, batch.schema)
                # Write the batch to the client stream
                writer.write_batch(batch)
        # Cleanup client connection
        finally:
            if writer is not None:
                writer.close()
            outfile.close()
            conn.close()

    sock.close()

def start_server_process(host_addr, host_port, serve_dir):
    """Start a process to serve CSV data as an Arrow stream."""
    server = multiprocessing.Process(target=serve_csv_data,
                                     args=(host_addr, host_port, serve_dir))
    server.daemon = True
    server.start()

def make_local_dataset(filename):
    """Make a TensorFlow Arrow Dataset that reads from a local CSV file."""
    # Read the local file and get a record batch iterator
    batch_iter = read_and_process(filename)

    # Create the Arrow Dataset as a stream from local iterator of record batches;
    # record_batch_iter_factory re-creates the iterator so the data can be read
    # again on subsequent epochs
    ds = arrow_io.ArrowStreamDataset.from_record_batches(
        batch_iter,
        columns=(0, 1, 2),
        output_types=(tf.int64, tf.float64, tf.float64),
        output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([])),
        batch_mode='auto',
        record_batch_iter_factory=partial(read_and_process, filename))

    # Map the dataset to combine feature columns to single tensor
    ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
    return ds

def make_remote_dataset(endpoint):
    """Make a TensorFlow Arrow Dataset that reads from a remote Arrow stream."""
    # Create the Arrow Dataset from a remote endpoint serving a stream
    ds = arrow_io.ArrowStreamDataset(
        [endpoint],
        columns=(0, 1, 2),
        output_types=(tf.int64, tf.float64, tf.float64),
        batch_mode='auto')

    # Map the dataset to combine feature columns to single tensor
    ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
    return ds

def model_fit(ds):
    """Create and fit a Keras logistic regression model."""
    # Build the Keras model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(2,), activation='sigmoid'))
    model.compile(optimizer='sgd', loss='mean_squared_error', metrics=['accuracy'])

    # Fit the model on the given dataset
    model.fit(ds, epochs=5, shuffle=False)
    return model

if __name__ == '__main__':

    # Parse flag to run local or remote dataset
    run_remote = False
    if len(sys.argv) >= 2 and sys.argv[1] == '--run-remote':
        run_remote = True

    # Write sample data to a CSV file
    filename = 'sample.csv'
    num_records = 1000
    write_csv(filename, num_records)

    if not run_remote:
        print('Running model fit on local file: {}'.format(filename))
        make_dataset_fn = partial(make_local_dataset,
                                  filename=filename)
    else:
        host_addr = '127.0.0.1'
        host_port = 8888
        serve_dir = './'
        print('Running model fit with remote host: {}:{}, serving directory: {}'
              .format(host_addr, host_port, serve_dir))
        start_server_process(host_addr, host_port, serve_dir)
        make_dataset_fn = partial(make_remote_dataset,
                                  endpoint='{}:{}'.format(host_addr, host_port))

    # Create the dataset
    ds = make_dataset_fn()

    # Fit the model
    model = model_fit(ds)
    print("Fit model with weights: {}".format(model.get_weights()))