Skip to content

Instantly share code, notes, and snippets.

@jiewpeng
Last active December 5, 2018 14:43
Show Gist options
  • Save jiewpeng/87b739616deeb2b4904883cfeca3abd4 to your computer and use it in GitHub Desktop.
Save jiewpeng/87b739616deeb2b4904883cfeca3abd4 to your computer and use it in GitHub Desktop.
Keras Model within Estimator Function
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import shutil
def create_model(max_seq_len, embedding_size):
    """Build the Keras classifier applied to a sequence of word embeddings.

    Layers, in order: dropout (rate 0.5), a separable 1-D convolution
    (8 filters, kernel size 3, 'same' padding, leaky-ReLU), global average
    pooling over the sequence axis, and a 2-unit dense output layer.

    Args:
        max_seq_len: Number of tokens per example; first dimension of the
            per-example input shape.
        embedding_size: Width of each token embedding; second dimension of
            the per-example input shape.

    Returns:
        A `tf.keras.Sequential` model whose output is a pair of unnormalised
        class scores (logits) per example.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dropout(0.5, input_shape=(max_seq_len, embedding_size)))
    model.add(tf.keras.layers.SeparableConv1D(8, 3, padding='same', activation=tf.nn.leaky_relu))
    model.add(tf.keras.layers.GlobalAveragePooling1D())
    # The output layer is deliberately linear. Callers in this file treat the
    # model output as logits: it is fed to
    # tf.losses.sparse_softmax_cross_entropy(logits=...) and to tf.nn.softmax.
    # A softmax activation here would therefore be applied twice, flattening
    # the predicted probabilities and weakening the training gradients.
    model.add(tf.keras.layers.Dense(2))
    return model
def pad_seq(text, max_seq_len):
    """Split each string into tokens and force a fixed token count per row.

    The input is flattened to one dimension, split with `tf.strings.split`
    (default separator), and densified with '' filling the ragged positions.
    The token axis is then right-padded with '' when it is shorter than
    `max_seq_len`, or cut down to the first `max_seq_len` tokens otherwise.

    Args:
        text: String tensor (any shape; it is reshaped to 1-D).
        max_seq_len: Target number of tokens per row.

    Returns:
        A dense string tensor with one row per input string and
        `max_seq_len` columns.
    """
    flat_text = tf.reshape(text, [-1])
    tokens = tf.sparse.to_dense(tf.strings.split(flat_text), default_value='')

    token_shape = tf.shape(tokens)
    batch_size = token_shape[0]
    seq_len = token_shape[1]

    def _pad_right():
        # Only the trailing edge of the token axis is padded.
        missing = max_seq_len - seq_len
        return tf.pad(tokens, [[0, 0], [0, missing]], constant_values='')

    def _truncate():
        # Keep every row, but only the leading max_seq_len tokens.
        return tf.slice(tokens, [0, 0], [batch_size, max_seq_len])

    # seq_len is only known at graph-run time, so the choice between padding
    # and truncating has to be a graph-level conditional.
    return tf.cond(seq_len < max_seq_len, _pad_right, _truncate)
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Tokenises and pads the input text, embeds every token with a TF-Hub
    module, runs the Keras model from `create_model` on the embeddings and
    returns the `EstimatorSpec` for the requested mode. The model output is
    treated as unnormalised logits throughout.

    Args:
        features: Either a string tensor of texts, or a dict holding that
            tensor under the key 'text'.
        labels: Integer class labels; used in TRAIN and EVAL modes only.
        mode: One of the `tf.estimator.ModeKeys` values.
        params: Dict-like of optional settings:
            'tfhub_url' (default 'https://tfhub.dev/google/nnlm-en-dim128/1'),
            'embedding_trainable' (default False),
            'max_seq_len' (default 5),
            'multi_gpu' (falsy by default; wraps the optimizer in
            `TowerOptimizer` when truthy).

    Returns:
        A `tf.estimator.EstimatorSpec` for PREDICT, TRAIN or EVAL. For any
        other mode value the function falls through and returns None.

    Raises:
        KeyError: If 'tfhub_url' is not one of the module URLs whose
            embedding size is listed below.
    """
    tfhub_url = params.get('tfhub_url', 'https://tfhub.dev/google/nnlm-en-dim128/1')
    embedding_trainable = params.get('embedding_trainable', False)
    max_seq_len = params.get('max_seq_len', 5)

    # Embedding width of each supported TF-Hub module; the Keras model needs
    # it up front to declare its input shape.
    embedding_size_dict = {
        'https://tfhub.dev/google/nnlm-en-dim128/1': 128,
        'https://tfhub.dev/google/Wiki-words-250/1': 250,
    }
    model = create_model(max_seq_len, embedding_size_dict[tfhub_url])
    embed = hub.Module(tfhub_url, trainable=embedding_trainable)

    text = features
    if isinstance(text, dict):
        text = text['text']

    text_seq = pad_seq(text, max_seq_len)
    # Apply the embedding module to each row of tokens separately.
    # Presumably yields (batch, max_seq_len, embedding_size) -- TODO confirm.
    embeddings = tf.map_fn(embed, text_seq, dtype=tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(embeddings, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'predict': tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

        # If we are running multi-GPU, we need to wrap the optimizer.
        if params.get('multi_gpu'):
            optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

        logits = model(embeddings, training=True)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        # tf.metrics.accuracy returns (value, update_op); index 1 is the
        # update op, which is what gets logged and summarised below.
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(loss, 'cross_entropy')
        tf.identity(accuracy[1], name='train_accuracy')

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar('train_accuracy', accuracy[1])

        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))

    if mode == tf.estimator.ModeKeys.EVAL:
        # Evaluate on the same embedded text as the other modes. This used to
        # reference an undefined name `image`, which raised NameError on
        # every evaluate() call.
        logits = model(embeddings, training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                'accuracy':
                    tf.metrics.accuracy(
                        labels=labels, predictions=tf.argmax(logits, axis=1)),
            })
# Smoke test: train the estimator for a few steps on two toy sentences and
# print the prediction for each of them.
MODEL_DIR = './model_trained'

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=MODEL_DIR)

x = np.array(['the quick brown fox', 'jumps over a lazy dog'])
y = np.array([0, 1])
input_fn = tf.estimator.inputs.numpy_input_fn(x={'text': x}, y=y, shuffle=False)

# Remove whatever a previous run left in the model directory; a missing
# directory is not an error.
shutil.rmtree(MODEL_DIR, ignore_errors=True)

estimator.train(input_fn=input_fn, steps=10)

for pred in estimator.predict(input_fn=input_fn):
    print(pred)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment