Training and exporting a Keras model to the TFLite format
#! /usr/bin/env python
import tensorflow as tf

# Load MNIST and scale pixel values from [0, 255] to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Small fully connected classifier for 28x28 grayscale digits.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

# Save the trained model as HDF5 for the conversion script.
model.save('keras_mnist.h5')
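Before converting, it helps to confirm that the float model itself is accurate, so any later drop can be attributed to quantization rather than training. The snippet below is a minimal sanity check, not part of the original gist, assuming the script above has already written keras_mnist.h5.
#! /usr/bin/env python
import tensorflow as tf

# Reload the saved model and score it on the rescaled MNIST test set.
mnist = tf.keras.datasets.mnist
_, (x_test, y_test) = mnist.load_data()
x_test = x_test / 255.0

model = tf.keras.models.load_model('keras_mnist.h5')
loss, accuracy = model.evaluate(x_test, y_test)
print('float model: loss={:.4f}, accuracy={:.4f}'.format(loss, accuracy))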
#! /usr/bin/env python
import tensorflow.contrib.lite as lite

# Convert the saved Keras model to a TFLite flatbuffer (TF 1.x contrib API).
converter = lite.TFLiteConverter.from_keras_model_file('keras_mnist.h5')

# Quantize the weights after training and accept uint8 input images.
converter.post_training_quantize = True
converter.inference_input_type = lite.constants.QUANTIZED_UINT8
# (mean, std_dev) for the input tensor: real_value = (uint8_value - 0.0) / 255.0,
# which matches the [0, 1] scaling used during training.
converter.quantized_input_stats = {"flatten_input": (0.0, 255.0)}

flatbuffer = converter.convert()

with open('keras_mnist.tflite', 'wb') as outfile:
    outfile.write(flatbuffer)
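Post-training quantization stores the weights as 8-bit values, so the resulting flatbuffer should be noticeably smaller than the float Keras model. The following optional check is not part of the original gist and only assumes the file names used above.
#! /usr/bin/env python
import os

# Compare on-disk sizes of the float Keras model and the quantized TFLite model.
h5_size = os.path.getsize('keras_mnist.h5')
tflite_size = os.path.getsize('keras_mnist.tflite')
print('keras_mnist.h5:     {:d} bytes'.format(h5_size))
print('keras_mnist.tflite: {:d} bytes'.format(tflite_size))
print('size ratio: {:.2f}'.format(tflite_size / float(h5_size)))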
#! /usr/bin/env python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite

mnist = tf.keras.datasets.mnist
batch_size = 32

# Only the test images are needed; they are uint8 in [0, 255], which matches
# the quantized input type of the converted model.
_, (x_test, y_test) = mnist.load_data()

# Load the TFLite model and look up its input and output tensors.
interpreter = lite.Interpreter(model_path='keras_mnist.tflite')
input_info = interpreter.get_input_details()[0]
output_info = interpreter.get_output_details()[0]

# Resize the input from a single image to a batch before allocating tensors.
interpreter.resize_tensor_input(input_info['index'], (batch_size, 28, 28))
interpreter.allocate_tensors()

# Run one batch through the interpreter and read back the class probabilities.
interpreter.set_tensor(input_info['index'], x_test[0:batch_size])
interpreter.invoke()
probs = interpreter.get_tensor(output_info['index'])

print('predicted={}, label={}'.format(np.argmax(probs, axis=-1), y_test[0:batch_size]))
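The script above only prints predictions for one batch. To put an accuracy number on the quantized model, the same interpreter calls can be looped over the full test set. This is a sketch under the same assumptions as the script above (batch size 32, uint8 MNIST inputs), not part of the original gist.
#! /usr/bin/env python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite

mnist = tf.keras.datasets.mnist
batch_size = 32
_, (x_test, y_test) = mnist.load_data()

interpreter = lite.Interpreter(model_path='keras_mnist.tflite')
input_info = interpreter.get_input_details()[0]
output_info = interpreter.get_output_details()[0]
interpreter.resize_tensor_input(input_info['index'], (batch_size, 28, 28))
interpreter.allocate_tensors()

correct = 0
total = 0
# Iterate over full batches only, so the resized input shape always matches.
for start in range(0, len(x_test) - batch_size + 1, batch_size):
    batch_x = x_test[start:start + batch_size]
    batch_y = y_test[start:start + batch_size]
    interpreter.set_tensor(input_info['index'], batch_x)
    interpreter.invoke()
    probs = interpreter.get_tensor(output_info['index'])
    correct += np.sum(np.argmax(probs, axis=-1) == batch_y)
    total += batch_size

print('quantized model accuracy: {:.4f}'.format(correct / float(total)))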