import pandas as pd
import tensorflow as tf
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
from tensorflow.python.saved_model import loader
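# Note: this gist uses the TF 1.x Session / SavedModel loader API. On TensorFlow 2.x
# the same calls are generally available under tf.compat.v1 (e.g. tf.compat.v1.Session).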

def encode_factory(sess, export_path: str):
    """Loads a TF SavedModel into `sess` and returns a callable that embeds a batch of docs."""
    output_tensor_names = ["input_layer/concat:0"]
    loader.load(sess, ["serve"], export_path)  # tags must be passed as a list

    def encode(docs):
        inputs_feed_dict = {"input_example_tensor:0": list(docs)}
        batch = sess.run(output_tensor_names, feed_dict=inputs_feed_dict)
        return batch[0]  # single output tensor -> array of shape (n_docs, embedding_dim)

    return encode
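
# The tensor names above ("input_example_tensor:0" / "input_layer/concat:0") are specific
# to how this particular model was exported. For a different SavedModel they can be
# looked up with TensorFlow's CLI, e.g.:
#   saved_model_cli show --dir <export_path> --tag_set serve --signature_def serving_default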

def map_fn(pdf, encoder, export_path):
    """Runs on each worker: loads the model into a fresh session and encodes one partition."""
    with tf.Session() as sess:
        encode = encoder(sess, export_path)
        embedded_docs = encode(pdf.docs)   # run TF graph on this batch of docs
    pdf["encoded"] = tuple(embedded_docs)  # tuple so pandas stores one vector per row
    return pdf
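
# Optional local sanity check before scaling out (uses a hypothetical local sample file
# and the export_path defined below):
#   sample_pdf = pd.read_csv("sample.fasttext", names=["docs"])
#   print(map_fn(sample_pdf, encode_factory, export_path)["encoded"].iloc[0])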

# Start Dask scheduler
cluster = LocalCluster()
client = Client(cluster)

export_path = "s3://.../model/"  # path to the exported TF SavedModel (placeholder)
npartitions = 100                # tune to cluster size / data volume

# Extract
docs_ddf = dd.read_csv("s3://.../data/*.fasttext", names=["docs"])  # fasttext format
docs_ddf = docs_ddf.repartition(npartitions=npartitions)
docs_ddf = client.persist(docs_ddf)  # cache the raw docs in cluster memory

# Broadcast the TF factory function to all workers
classifier_future = client.scatter(encode_factory, broadcast=True)

# Transform: run the TF model on each partition
encoded_ddf = docs_ddf.map_partitions(map_fn, classifier_future, export_path)

# Load
encoded_ddf.to_csv("s3://.../index/v1/*.csv")
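
# After (or instead of) the full write, a small sample can be pulled back to the driver
# to spot-check the result, e.g.:
#   print(encoded_ddf.head()[["docs", "encoded"]])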