last_model on tf experimental 1.14
# Note this needs to happen before importing tensorflow.
import os
os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'

import argparse
import sys

import tensorflow as tf


class MnistLstmModel(object):
  """Build a simple LSTM based MNIST model.

  Attributes:
    time_steps: The maximum length of the time_steps; since we're just using
      the 'width' dimension as time_steps, it's actually a fixed number.
    input_size: The LSTM layer input size.
    num_lstm_layer: Number of LSTM layers for the stacked LSTM cell case.
    num_lstm_units: Number of units in the LSTM cell.
    units: The units for the last layer.
    num_class: Number of classes to predict.
  """

  def __init__(self, time_steps, input_size, num_lstm_layer, num_lstm_units,
               units, num_class):
    self.time_steps = time_steps
    self.input_size = input_size
    self.num_lstm_layer = num_lstm_layer
    self.num_lstm_units = num_lstm_units
    self.units = units
    self.num_class = num_class

  def build_model(self):
    """Build the model using the given configs.

    Returns:
      x: The input placeholder tensor.
      logits: The logits of the output.
      output_class: The prediction.
    """
    x = tf.placeholder(
        'float32', [None, self.time_steps, self.input_size], name='INPUT')
    lstm_layers = []
    for _ in range(self.num_lstm_layer):
      lstm_layers.append(
          # Important:
          #
          # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`
          # (OpHinted LSTMCell).
          tf.lite.experimental.nn.TFLiteLSTMCell(
              self.num_lstm_units, forget_bias=0))
    # Weights and biases for the output softmax layer.
    out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
    out_bias = tf.Variable(tf.zeros([self.num_class]))
    # Transpose input x to make it time major.
    lstm_inputs = tf.transpose(x, perm=[1, 0, 2])
    lstm_cells = tf.keras.layers.StackedRNNCells(lstm_layers)
    # Important:
    #
    # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
    # is set to True.
    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
        lstm_cells, lstm_inputs, dtype='float32', time_major=True)
    # Transpose the outputs back to [batch, time, output].
    outputs = tf.transpose(outputs, perm=[1, 0, 2])
    outputs = tf.unstack(outputs, axis=1)
    logits = tf.matmul(outputs[-1], out_weights) + out_bias
    output_class = tf.nn.softmax(logits, name='OUTPUT_CLASS')
    return x, logits, output_class


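# Shape walkthrough for the default config used below (a sketch; `B` stands
# for the batch size and is an illustrative name, not one from the code):
#   x            [B, 28, 28]   (batch, time_steps, input_size)
#   lstm_inputs  [28, B, 28]   time major, after the first transpose
#   outputs      [B, 28, 64]   after transposing back (num_lstm_units = 64)
#   logits       [B, 10]       computed from the last time step only

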
def train(model,
          model_dir,
          batch_size=20,
          learning_rate=0.001,
          train_steps=2000,
          eval_steps=500,
          save_every_n_steps=1000):
  """Train & save the MNIST recognition model."""
  # Train & test dataset.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_iterator = train_dataset.shuffle(
      buffer_size=1000).batch(batch_size).repeat().make_one_shot_iterator()
  x, logits, output_class = model.build_model()
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_iterator = test_dataset.batch(
      batch_size).repeat().make_one_shot_iterator()
  # Input label placeholder.
  y = tf.placeholder(tf.int32, [
      None,
  ])
  one_hot_labels = tf.one_hot(y, depth=model.num_class)
  # Loss function.
  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=one_hot_labels))
  correct = tf.nn.in_top_k(output_class, y, 1)
  accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
  # Optimization.
  opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
  # Initialize variables.
  init = tf.global_variables_initializer()
  saver = tf.train.Saver()
  batch_x, batch_y = train_iterator.get_next()
  batch_test_x, batch_test_y = test_iterator.get_next()
  with tf.Session() as sess:
    sess.run([init])
    for i in range(train_steps):
      batch_x_value, batch_y_value = sess.run([batch_x, batch_y])
      _, loss_value = sess.run([opt, loss],
                               feed_dict={
                                   x: batch_x_value,
                                   y: batch_y_value
                               })
      if i % 100 == 0:
        tf.logging.info('Training step %d, loss is %f' % (i, loss_value))
      if i > 0 and i % save_every_n_steps == 0:
        accuracy_sum = 0.0
        for _ in range(eval_steps):
          test_x_value, test_y_value = sess.run([batch_test_x, batch_test_y])
          accuracy_value = sess.run(
              accuracy, feed_dict={
                  x: test_x_value,
                  y: test_y_value
              })
          accuracy_sum += accuracy_value
        tf.logging.info('Training step %d, accuracy is %f' %
                        (i, accuracy_sum / (eval_steps * 1.0)))
        saver.save(sess, model_dir)


def export(model, model_dir, tflite_model_file,
           use_post_training_quantize=True):
  """Export the trained model to a tflite model."""
  tf.reset_default_graph()
  x, _, output_class = model.build_model()
  saver = tf.train.Saver()
  sess = tf.Session()
  saver.restore(sess, model_dir)
  # Convert to a TfLite model.
  converter = tf.lite.TFLiteConverter.from_session(sess, [x], [output_class])
  converter.post_training_quantize = use_post_training_quantize
  tflite = converter.convert()
  with open(tflite_model_file, 'wb') as f:
    f.write(tflite)


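# Optional sanity check, not part of the original tutorial: a minimal sketch
# of running the exported file with the TF 1.x `tf.lite.Interpreter` API.
# The helper name and the random input are illustrative assumptions.
def check_tflite_model(tflite_model_file):
  """Loads the exported tflite model and runs one random input through it."""
  import numpy as np
  interpreter = tf.lite.Interpreter(model_path=tflite_model_file)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # Build a random input matching the converted INPUT tensor's shape.
  fake_input = np.random.rand(*input_details[0]['shape']).astype('float32')
  interpreter.set_tensor(input_details[0]['index'], fake_input)
  interpreter.invoke()
  # OUTPUT_CLASS holds the softmax probabilities over the 10 digit classes.
  probabilities = interpreter.get_tensor(output_details[0]['index'])
  print('Predicted digit:', probabilities.argmax())

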
def train_and_export(parsed_flags):
  """Train the MNIST LSTM model and export it to TfLite."""
  model = MnistLstmModel(
      time_steps=28,
      input_size=28,
      num_lstm_layer=2,
      num_lstm_units=64,
      units=64,
      num_class=10)
  tf.logging.info('Starting training...')
  train(model, parsed_flags.model_dir)
  tf.logging.info('Finished training, exporting the tflite model to %s ...' %
                  parsed_flags.tflite_model_file)
  export(model, parsed_flags.model_dir, parsed_flags.tflite_model_file,
         parsed_flags.use_post_training_quantize)
  tf.logging.info(
      'Finished exporting, model is %s' % parsed_flags.tflite_model_file)


def run_main(_):
  """Main in the TfLite LSTM tutorial."""
  parser = argparse.ArgumentParser(
      description='Train an MNIST recognition model, then export to TfLite.')
  parser.add_argument(
      '--model_dir',
      type=str,
      help='Directory where the model checkpoints will be stored.',
      required=True)
  parser.add_argument(
      '--tflite_model_file',
      type=str,
      help='Full filepath for the exported tflite model file.',
      required=True)
  # `action='store_true'` combined with `default=True` made the original flag
  # a no-op; keep the default but add a companion flag so quantization can
  # actually be disabled.
  parser.add_argument(
      '--use_post_training_quantize',
      dest='use_post_training_quantize',
      action='store_true',
      default=True,
      help='Use post training quantize (enabled by default).')
  parser.add_argument(
      '--no_use_post_training_quantize',
      dest='use_post_training_quantize',
      action='store_false',
      help='Disable post training quantize.')
  parsed_flags, _ = parser.parse_known_args()
  train_and_export(parsed_flags)


if __name__ == '__main__':
  # Raise the log level so the tf.logging.info calls above are visible.
  tf.logging.set_verbosity(tf.logging.INFO)
  run_main(sys.argv[:1])
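# Example invocation (the script name and paths are illustrative, not from
# the original gist):
#   python mnist_lstm.py --model_dir /tmp/mnist_lstm \
#     --tflite_model_file /tmp/mnist_lstm.tflite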