Created
February 23, 2018 05:12
-
-
Save elmarhaussmann/6a2804a334baff9aed79aa7f1ed4b038 to your computer and use it in GitHub Desktop.
Character based text classification with TPUEstimator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Based on the example from the TensorFlow repository (https://github.com/tensorflow/tensorflow/):
# https://github.com/tensorflow/tensorflow/blob/671baf080238025da9698ea980cd9504005f727c/tensorflow/examples/learn/text_classification_character_rnn.py
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import sys | |
import numpy as np | |
import pandas | |
import tensorflow as tf | |
from tensorflow.contrib.tpu.python.tpu import tpu_config | |
from tensorflow.contrib.tpu.python.tpu import tpu_estimator | |
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer | |
from tensorflow.contrib.rnn import MultiRNNCell, static_bidirectional_rnn | |
FLAGS = tf.flags.FLAGS

# Number of bytes kept per document; also the number of RNN time steps.
MAX_DOCUMENT_LENGTH = 128
# Number of stacked LSTM layers in each direction of the bidirectional RNN.
NUM_LAYERS = 1
# Number of units in every LSTM cell.
HIDDEN_SIZE = 1024
# Number of output logits of the classifier. NOTE(review): presumably 15 so
# that DBPedia's 1-based labels (1..14) can be used directly as class
# indices -- confirm against the dataset.
MAX_LABEL = 15
CHARS_FEATURE = 'chars'  # Name of the input character feature.

tf.flags.DEFINE_bool(
    'use_tpu', True,
    help=('Use TPU to execute the model for training and evaluation. If'
          ' --use_tpu=false, will use whatever devices are available to'
          ' TensorFlow by default (e.g. CPU and GPU)'))

# Cloud TPU Cluster Resolvers
tf.flags.DEFINE_string(
    'tpu_name', default=None,
    help='Name of the Cloud TPU for Cluster Resolvers. You must specify either '
    'this flag or --master.')
tf.flags.DEFINE_string(
    'gcp_project', default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
tf.flags.DEFINE_string(
    'tpu_zone', default=None,
    help='GCE zone where the Cloud TPU is located. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
tf.flags.DEFINE_string(
    'master', default=None,
    help='gRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You '
    'must specify either this flag or --tpu_name.')

# Hardware and training-loop configuration.
tf.flags.DEFINE_integer(
    'num_cores', default=8,
    help=('Number of TPU cores. For a single TPU device, this is 8 because each'
          ' TPU has 4 chips each with 2 cores.'))
tf.flags.DEFINE_string(
    'model_dir', default=None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
tf.flags.DEFINE_integer(
    'iterations_per_loop', default=200,
    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
          ' If the number of iterations in the loop would exceed the number of'
          ' train steps, the loop will exit before reaching'
          ' --iterations_per_loop. The larger this value is, the higher the'
          ' utilization on the TPU.'))
tf.flags.DEFINE_integer(
    'train_steps', default=2000,
    help=('The number of steps to use for training.'))
tf.flags.DEFINE_integer(
    'eval_steps', default=100,
    help=('Total number of evaluation steps. If `0`, evaluation '
          'after training is skipped.'))
tf.flags.DEFINE_integer(
    'train_batch_size', default=32, help='Batch size for training.')
tf.flags.DEFINE_integer(
    'eval_batch_size', default=32, help='Batch size for evaluation.')
def char_rnn_model(features, labels, mode, params):
  """Character level recurrent neural network model to predict classes.

  A TPUEstimator `model_fn`: one-hot encodes the input bytes, runs a
  (possibly stacked) bidirectional LSTM over them and classifies the
  concatenated final hidden states of both directions.

  Args:
    features: dict whose CHARS_FEATURE entry holds integer byte ids; assumed
      to be shaped (batch, MAX_DOCUMENT_LENGTH) -- TODO confirm against the
      input_fn.
    labels: integer class ids, scored with sparse softmax cross-entropy over
      MAX_LABEL logits.
    mode: a tf.estimator.ModeKeys value; only TRAIN and EVAL are supported.
    params: dict supplied by TPUEstimator; not read here.

  Returns:
    A tpu_estimator.TPUEstimatorSpec for TRAIN or EVAL.

  Raises:
    RuntimeError: if `mode` is PREDICT.
  """
  if mode == tf.estimator.ModeKeys.PREDICT:
    raise RuntimeError("mode {} is not supported yet".format(mode))

  # Each byte id becomes a 256-way one-hot vector; unstacking along axis 1
  # yields the per-time-step list of inputs the static RNN API expects.
  byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.)
  byte_list = tf.unstack(byte_vectors, axis=1)

  def lstm_cell():
    return tf.nn.rnn_cell.LSTMCell(HIDDEN_SIZE)

  # A fresh cell object is built for every layer of every direction.
  fw_cell = MultiRNNCell(cells=[lstm_cell() for _ in range(NUM_LAYERS)])
  bw_cell = MultiRNNCell(cells=[lstm_cell() for _ in range(NUM_LAYERS)])
  _, fw_state, bw_state = static_bidirectional_rnn(
      fw_cell, bw_cell, byte_list, dtype=tf.float32)

  # Classify from the final top-layer hidden state of each direction: the
  # forward state has read the sequence front-to-back, the backward state
  # back-to-front. Using outputs[-1] instead would pair the forward summary
  # with a backward output that has read only the last byte. The classifier
  # input width is unchanged (2 * HIDDEN_SIZE).
  encoding = tf.concat([fw_state[-1].h, bw_state[-1].h], axis=1)
  logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)
  # A single tf.nn.rnn_cell.GRUCell run with tf.nn.static_rnn is a simpler
  # alternative encoder.

  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    if FLAGS.use_tpu:
      # Combine gradients across TPU shards before they are applied.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    # TPUEstimatorSpec takes metrics as a (metric_fn, tensors) pair rather
    # than ready-made metric ops.
    def metric_fn(labels, logits):
      predictions = tf.argmax(logits, axis=1)
      accuracy = tf.metrics.accuracy(labels, predictions)
      return {"accuracy": accuracy}
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
class DBPedia(object):
  """DBPedia text-classification data exposed as TPUEstimator input functions.

  Construction loads both splits and byte-encodes every document up front;
  `input_fn` and `eval_input_fn` then wrap the in-memory arrays in tf.data
  pipelines.
  """

  def __init__(self, size='small'):
    """Loads the dataset.

    Args:
      size: dataset size name forwarded to
        tf.contrib.learn.datasets.load_dataset (e.g. 'small').
    """
    self.load_dbpedia(size)

  def load_dbpedia(self, size):
    """Loads DBPedia and stores byte-encoded features and labels.

    Sets `x_train`/`x_test` (int32 arrays, one row of byte ids per document,
    produced by a ByteProcessor with MAX_DOCUMENT_LENGTH) and
    `y_train`/`y_test` (int32 pandas Series of labels).

    Args:
      size: dataset size name forwarded to the tf.contrib.learn loader.
    """
    print("Loading DBPedia data...")
    dbpedia = tf.contrib.learn.datasets.load_dataset(
        'dbpedia', size=size, test_with_fake_data=False)
    # Column 1 of the raw feature table is used as the document text.
    # NOTE(review): presumably the article body/abstract -- confirm.
    x_train = pandas.DataFrame(dbpedia.train.data)[1]
    x_test = pandas.DataFrame(dbpedia.test.data)[1]

    # Map each document to a fixed-length sequence of byte ids.
    char_processor = tf.contrib.learn.preprocessing.ByteProcessor(
        MAX_DOCUMENT_LENGTH)
    x_train = list(char_processor.fit_transform(x_train))
    x_test = list(char_processor.transform(x_test))
    self.x_train = np.array(x_train, dtype=np.int32)
    self.x_test = np.array(x_test, dtype=np.int32)
    self.y_train = pandas.Series(dbpedia.train.target, dtype=np.int32)
    self.y_test = pandas.Series(dbpedia.test.target, dtype=np.int32)
    print("Train shape:", self.x_train.shape)
    print("Test shape:", self.x_test.shape)

  def _make_dataset(self, features, labels, batch_size, shuffle):
    """Builds a cached, endlessly repeating dataset of fixed-size batches."""
    tensors = ({CHARS_FEATURE: features}, labels)
    dataset = tf.data.Dataset.from_tensor_slices(tensors)
    dataset = dataset.cache().repeat()
    if shuffle:
      # The buffer spans the whole split.
      dataset = dataset.shuffle(features.shape[0])
    # Dropping the final partial batch keeps every batch the same size.
    return dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

  def input_fn(self, params):
    """Training input_fn: shuffled, repeated batches of the train split.

    Args:
      params: dict from TPUEstimator; only params["batch_size"] is read.
    """
    return self._make_dataset(self.x_train, self.y_train,
                              params["batch_size"], shuffle=True)

  def eval_input_fn(self, params):
    """Evaluation input_fn: repeated, unshuffled batches of the test split.

    Args:
      params: dict from TPUEstimator; only params["batch_size"] is read.
    """
    return self._make_dataset(self.x_test, self.y_test,
                              params["batch_size"], shuffle=False)
def main(unused_argv):
  """Trains the character RNN on DBPedia and optionally evaluates it.

  When --use_tpu is set, the TPU master address comes from --master, or
  otherwise is resolved from --tpu_name/--tpu_zone/--gcp_project. The model
  trains for --train_steps steps. If --eval_steps is positive, it is then
  evaluated on the test split and the accuracy is printed.

  Raises:
    RuntimeError: if --use_tpu is set but neither --master nor --tpu_name is.
  """
  if FLAGS.use_tpu:
    # Determine the gRPC URL of the TPU device to use.
    if FLAGS.master is None and FLAGS.tpu_name is None:
      raise RuntimeError('You must specify either --master or --tpu_name.')
    if FLAGS.master is not None:
      if FLAGS.tpu_name is not None:
        tf.logging.warn('Both --master and --tpu_name are set. Ignoring'
                        ' --tpu_name and using --master.')
      tpu_grpc_url = FLAGS.master
    else:
      tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
          tpu_names=[FLAGS.tpu_name],
          zone=FLAGS.tpu_zone,
          project=FLAGS.gcp_project)
      tpu_grpc_url = tpu_cluster_resolver.get_master()
  else:
    # URL is unused if running locally without TPU.
    tpu_grpc_url = None

  config = tpu_config.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      save_summary_steps=200,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_cores))

  # One DBPedia object holds both the train and the test split.
  data = DBPedia(size='small')

  classifier = tpu_estimator.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=char_rnn_model,
      config=config,
      eval_batch_size=FLAGS.eval_batch_size,
      train_batch_size=FLAGS.train_batch_size)

  step_hook = tf.train.StepCounterHook(output_dir=FLAGS.model_dir,
                                       every_n_steps=10)

  classifier.train(input_fn=data.input_fn, max_steps=FLAGS.train_steps,
                   hooks=[step_hook])

  # --eval_steps=0 is documented as "skip evaluation"; Estimator.evaluate
  # itself rejects a non-positive step count with a ValueError.
  if FLAGS.eval_steps > 0:
    scores = classifier.evaluate(input_fn=data.eval_input_fn,
                                 steps=FLAGS.eval_steps)
    print('Accuracy: {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
  # Show TensorFlow messages at INFO level and above.
  tf.logging.set_verbosity(tf.logging.INFO)
  # tf.app.run parses the command-line flags, then calls main(argv).
  tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment