Skip to content

Instantly share code, notes, and snippets.

@elmarhaussmann
Created February 23, 2018 05:12
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elmarhaussmann/6a2804a334baff9aed79aa7f1ed4b038 to your computer and use it in GitHub Desktop.
Save elmarhaussmann/6a2804a334baff9aed79aa7f1ed4b038 to your computer and use it in GitHub Desktop.
Character-based text classification with TPUEstimator
# Based on the example from the TensorFlow repository: https://github.com/tensorflow/tensorflow/
# https://github.com/tensorflow/tensorflow/blob/671baf080238025da9698ea980cd9504005f727c/tensorflow/examples/learn/text_classification_character_rnn.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import numpy as np
import pandas
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.contrib.rnn import MultiRNNCell, static_bidirectional_rnn
FLAGS = tf.flags.FLAGS

# Maximum document length handed to ByteProcessor when encoding the text.
MAX_DOCUMENT_LENGTH = 128
# Number of stacked LSTM layers in each direction of the bidirectional RNN.
NUM_LAYERS = 1
# Number of units in every LSTM cell.
HIDDEN_SIZE = 1024
# Number of output classes (width of the final logits layer).
MAX_LABEL = 15
CHARS_FEATURE = 'chars'  # Name of the input character feature.

tf.flags.DEFINE_bool(
    'use_tpu', default=True,
    help=('Use TPU to execute the model for training and evaluation. If'
          ' --use_tpu=false, will use whatever devices are available to'
          ' TensorFlow by default (e.g. CPU and GPU)'))

# Cloud TPU Cluster Resolvers
tf.flags.DEFINE_string(
    'tpu_name', default=None,
    help='Name of the Cloud TPU for Cluster Resolvers. You must specify either '
    'this flag or --master.')
tf.flags.DEFINE_string(
    'gcp_project', default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
tf.flags.DEFINE_string(
    'tpu_zone', default=None,
    help='GCE zone where the Cloud TPU is located. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
tf.flags.DEFINE_string(
    'master', default=None,
    help='gRPC URL of the master (i.e. grpc://ip.address.of.tpu:8470). You '
    'must specify either this flag or --tpu_name.')
tf.flags.DEFINE_integer(
    'num_cores', default=8,
    help=('Number of TPU cores. For a single TPU device, this is 8 because each'
          ' TPU has 4 chips each with 2 cores.'))

tf.flags.DEFINE_string(
    'model_dir', default=None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
tf.flags.DEFINE_integer(
    'iterations_per_loop', default=200,
    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
          ' If the number of iterations in the loop would exceed the number of'
          ' train steps, the loop will exit before reaching'
          ' --iterations_per_loop. The larger this value is, the higher the'
          ' utilization on the TPU.'))
tf.flags.DEFINE_integer(
    'train_steps', default=2000,
    help=('The number of steps to use for training.'))
tf.flags.DEFINE_integer(
    'eval_steps', default=100,
    help=('Total number of evaluation steps. If `0`, evaluation '
          'after training is skipped.'))
tf.flags.DEFINE_integer(
    'train_batch_size', default=32, help='Batch size for training.')
tf.flags.DEFINE_integer(
    'eval_batch_size', default=32, help='Batch size for evaluation.')
def char_rnn_model(features, labels, mode, params):
    """Character-level bidirectional LSTM classifier (TPUEstimator model_fn).

    Args:
        features: dict whose ``CHARS_FEATURE`` entry holds the byte ids of each
            document. Presumably an int32 tensor of shape
            (batch, MAX_DOCUMENT_LENGTH); TODO confirm against the input_fn.
        labels: integer class ids, expected in ``[0, MAX_LABEL)`` by
            ``sparse_softmax_cross_entropy``.
        mode: a ``tf.estimator.ModeKeys`` value. Only TRAIN and EVAL are
            supported.
        params: dict supplied by TPUEstimator (not used here).

    Returns:
        A ``TPUEstimatorSpec`` with the loss plus either a train op (TRAIN) or
        an accuracy metric (EVAL).

    Raises:
        RuntimeError: if ``mode`` is PREDICT or any other unsupported mode.
    """
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise RuntimeError("mode {} is not supported yet".format(mode))

    # One-hot encode every byte (256 possible values), then split the time
    # axis into a Python list, which is the input format of the static RNN API.
    byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.)
    byte_list = tf.unstack(byte_vectors, axis=1)

    def lstm_cell():
        return tf.nn.rnn_cell.LSTMCell(HIDDEN_SIZE)

    fw_cell = MultiRNNCell(cells=[lstm_cell() for _ in range(NUM_LAYERS)])
    bw_cell = MultiRNNCell(cells=[lstm_cell() for _ in range(NUM_LAYERS)])
    outputs, _, _ = static_bidirectional_rnn(fw_cell, bw_cell, byte_list,
                                             dtype=tf.float32)

    # Each outputs[t] is [fw_output_t, bw_output_t] concatenated on the feature
    # axis, each half HIDDEN_SIZE wide.
    # - The forward direction has consumed the whole document only at the
    #   LAST step.
    # - The backward direction has consumed it only at the FIRST step.
    # Using outputs[-1] alone would pair the full forward summary with a
    # backward output that has seen a single byte.
    fw_summary = outputs[-1][:, :HIDDEN_SIZE]
    bw_summary = outputs[0][:, HIDDEN_SIZE:]
    encoding = tf.concat([fw_summary, bw_summary], axis=1)
    # A lighter alternative is a unidirectional GRU (tf.nn.static_rnn over a
    # GRUCell) whose final state feeds this dense layer.
    logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        if FLAGS.use_tpu:
            # Wrap the optimizer so that gradients are combined across TPU
            # shards before they are applied.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op)

    # EVAL: TPUEstimatorSpec takes metrics as a (metric_fn, tensors) pair
    # rather than as ready-made metric ops.
    def metric_fn(labels, logits):
        predictions = tf.argmax(logits, axis=1)
        accuracy = tf.metrics.accuracy(labels, predictions)
        return {"accuracy": accuracy}

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
class DBPedia(object):
    """DBPedia text-classification data, byte-encoded and served as tf.data.

    After construction the instance holds:
      * ``x_train`` / ``x_test``: int32 numpy arrays of byte ids, one row per
        document.
      * ``y_train`` / ``y_test``: int32 pandas Series of class labels.
    """

    def __init__(self, size='small'):
        self.load_dbpedia(size)

    def load_dbpedia(self, size):
        """Load DBPedia via tf.contrib.learn and byte-encode the documents.

        Args:
            size: dataset size forwarded to ``load_dataset`` (e.g. 'small').
        """
        print("Loading DBPedia data...")
        dbpedia = tf.contrib.learn.datasets.load_dataset(
            'dbpedia', size=size, test_with_fake_data=False)
        # Column 1 of the raw data is what gets encoded as the document text.
        x_train = pandas.DataFrame(dbpedia.train.data)[1]
        x_test = pandas.DataFrame(dbpedia.test.data)[1]

        # Map each document to byte ids. ByteProcessor is configured with
        # MAX_DOCUMENT_LENGTH as its maximum document length.
        char_processor = tf.contrib.learn.preprocessing.ByteProcessor(
            MAX_DOCUMENT_LENGTH)
        x_train = list(char_processor.fit_transform(x_train))
        x_test = list(char_processor.transform(x_test))

        self.x_train = np.array(x_train, dtype=np.int32)
        self.x_test = np.array(x_test, dtype=np.int32)
        self.y_train = pandas.Series(dbpedia.train.target, dtype=np.int32)
        self.y_test = pandas.Series(dbpedia.test.target, dtype=np.int32)
        print("Train shape:", self.x_train.shape)
        print("Test shape:", self.x_test.shape)

    def _make_dataset(self, x, y, batch_size, shuffle):
        """Build an endlessly repeating dataset of ({CHARS_FEATURE: x}, y) batches.

        Every batch has exactly ``batch_size`` rows, because the remainder is
        dropped. Presumably this is so shapes stay static, as TPUs require.
        When ``shuffle`` is true, the shuffle buffer holds the whole split.
        """
        dataset = tf.data.Dataset.from_tensor_slices(({CHARS_FEATURE: x}, y))
        dataset = dataset.cache().repeat()
        if shuffle:
            dataset = dataset.shuffle(x.shape[0])
        return dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))

    def input_fn(self, params):
        """Training input_fn: shuffled, repeated batches of the train split."""
        return self._make_dataset(self.x_train, self.y_train,
                                  params["batch_size"], shuffle=True)

    def eval_input_fn(self, params):
        """Evaluation input_fn: repeated, unshuffled batches of the test split."""
        return self._make_dataset(self.x_test, self.y_test,
                                  params["batch_size"], shuffle=False)
def _resolve_tpu_master():
    """Return the gRPC URL of the TPU to use, or None when --use_tpu is false.

    --master takes precedence over --tpu_name. Otherwise the URL is looked up
    through a TPUClusterResolver.

    Raises:
        RuntimeError: if --use_tpu is set but neither --master nor --tpu_name is.
    """
    if not FLAGS.use_tpu:
        # URL is unused if running locally without TPU.
        return None
    if FLAGS.master is None and FLAGS.tpu_name is None:
        raise RuntimeError('You must specify either --master or --tpu_name.')
    if FLAGS.master is not None:
        if FLAGS.tpu_name is not None:
            tf.logging.warn('Both --master and --tpu_name are set. Ignoring'
                            ' --tpu_name and using --master.')
        return FLAGS.master
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu_names=[FLAGS.tpu_name],
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)
    return tpu_cluster_resolver.get_master()


def main(unused_argv):
    """Train the character RNN on DBPedia and, unless disabled, evaluate it.

    Args:
        unused_argv: leftover command-line arguments (ignored).
    """
    tpu_grpc_url = _resolve_tpu_master()

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_summary_steps=200,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    # Load data (the object holds both the train and the test split).
    dbpedia = DBPedia(size='small')

    # Build model
    classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=char_rnn_model,
        config=config,
        eval_batch_size=FLAGS.eval_batch_size,
        train_batch_size=FLAGS.train_batch_size)

    # Record the training step rate every 10 steps.
    step_hook = tf.train.StepCounterHook(output_dir=FLAGS.model_dir,
                                         every_n_steps=10)
    classifier.train(input_fn=dbpedia.input_fn, max_steps=FLAGS.train_steps,
                     hooks=[step_hook])

    # --eval_steps documents "0 skips evaluation", but Estimator.evaluate
    # raises ValueError for steps <= 0, so the skip has to be explicit.
    if FLAGS.eval_steps <= 0:
        tf.logging.info('Skipping evaluation because --eval_steps=%d.',
                        FLAGS.eval_steps)
        return
    scores = classifier.evaluate(input_fn=dbpedia.eval_input_fn,
                                 steps=FLAGS.eval_steps)
    print('Accuracy: {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
    # Emit INFO-level TensorFlow logs (training/eval progress).
    tf.logging.set_verbosity(tf.logging.INFO)
    # tf.app.run handles flag parsing before invoking the entry point; main is
    # named explicitly rather than looked up on the __main__ module.
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment