This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Download and unpack MobileNet v2 model from https://www.tensorflow.org/lite/guide/hosted_models | |
!curl -LO http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz | |
!curl -LO https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt | |
!tar -xvzf mobilenet_v2_1.0_224.tgz | |
# List of unpacked files | |
# >> mobilenet_v2_1.0_224.ckpt.data-00000-of-00001 | |
# >> mobilenet_v2_1.0_224.ckpt.index | |
# >> mobilenet_v2_1.0_224.ckpt.meta | |
# >> mobilenet_v2_1.0_224_eval.pbtxt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
!pip install tfcoreml | |
import tensorflow as tf | |
import tfcoreml | |
# TensorFlow 2.0 isn't yet supported. Make sure you use 1.x | |
print("TensorFlow version {}".format(tf.__version__)) | |
print("Eager mode: ", tf.executing_eagerly()) | |
print("Is GPU available: ", tf.test.is_gpu_available()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Concatenation of argmax and max value for each row
def max_values_only(data):
    """Summarize each row of a score matrix by its best index and best value.

    Args:
        data: Array-like of per-class scores, expected to be 2-D with one row
            per sample (e.g. model predictions of shape ``(batch, classes)``).

    Returns:
        ``np.ndarray`` of shape ``(rows, 2)``. Column 0 holds the argmax
        index of each row and column 1 holds that row's maximum value.
        Because both columns live in one array, NumPy promotes them to a
        common dtype, so for floating-point scores the indices are stored
        as floats.
    """
    # reshape(-1, 1) turns each per-row result into a column vector so the
    # two can be joined side by side with concatenate(axis=1).
    argmax_col = np.argmax(data, axis=1).reshape(-1, 1)
    max_col = np.max(data, axis=1).reshape(-1, 1)
    return np.concatenate([argmax_col, max_col], axis=1)
# Build simplified prediction tables: one [argmax, max] row per sample for the
# TF model, the TFLite model and the quantized TFLite model.
tf_model_pred_simplified = max_values_only(tf_model_predictions)
tflite_model_pred_simplified = max_values_only(tflite_model_predictions)
tflite_q_model_pred_simplified = max_values_only(tflite_q_model_predictions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Concatenate results from all models column-wise; `keys` adds an outer
# column level naming the model each block of columns came from.
all_models_dataframe = pd.concat([tf_pred_dataframe,
                                  tflite_pred_dataframe,
                                  tflite_q_pred_dataframe],
                                 keys=['TF Model', 'TFLite', 'TFLite quantized'],
                                 axis='columns')
# Swap column levels to have a side-by-side comparison: after swaplevel the
# original column names become the outer level, and selecting them in their
# original order places the three models' values for each column next to
# each other.
all_models_dataframe = all_models_dataframe.swaplevel(axis='columns')[tflite_pred_dataframe.columns]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Set batch of images into the interpreter's first input tensor.
# NOTE(review): set_tensor expects the batch to match the input tensor's
# current shape — confirm the cell that resizes the interpreter to batches of
# 32 and calls allocate_tensors() has run before this one.
tflite_interpreter.set_tensor(input_details[0]['index'], val_image_batch)
# Run inference
tflite_interpreter.invoke()
# Get prediction results from the first output tensor
tflite_model_predictions = tflite_interpreter.get_tensor(output_details[0]['index'])
print("Prediction results shape:", tflite_model_predictions.shape)
# >> Prediction results shape: (32, 5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Resize the interpreter for batches of 32 images and re-allocate its tensors.
tflite_interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
# NOTE(review): resize_tensor_input is documented for input tensors; output
# shapes are normally derived by allocate_tensors() from the resized input.
# Verify this explicit output resize is actually needed.
tflite_interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
tflite_interpreter.allocate_tensors()
# Re-read the tensor details so the printed shapes reflect the resize.
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the TFLite model from TFLITE_MODEL and inspect its first input and
# output tensors.
tflite_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
!mkdir "tflite_models" | |
TFLITE_MODEL = "tflite_models/flowers.tflite" | |
TFLITE_QUANT_MODEL = "tflite_models/flowers_quant.tflite" | |
# Get the concrete function from the Keras model. | |
run_model = tf.function(lambda x : flowers_model(x)) | |
# Trace a concrete function of run_model for a single input signature (this
# only builds the function; nothing is written to disk here).
# NOTE(review): the TensorSpec is built from `model.inputs[0]`, but run_model
# calls `flowers_model` — confirm `model` is the same Keras model, otherwise
# use flowers_model.inputs[0].
# NOTE(review): the closing parenthesis of this call is missing from this copy
# of the cell (it looks truncated); restore it before running.
concrete_func = run_model.get_concrete_function( | |
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get the predicted class id (index of the highest score) for each image.
predicted_ids = np.argmax(tf_model_predictions, axis=-1)
# NOTE(review): indexing with an array of ids assumes dataset_labels is a
# NumPy array of class names — confirm.
predicted_labels = dataset_labels[predicted_ids]
# Plot the images of the batch together with the predicted labels.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
# Draw the first 30 images of the batch on a 6x5 subplot grid.
# Assumes val_image_batch holds at least 30 images (a batch of 32 is recorded
# elsewhere in this notebook) — TODO confirm.
# NOTE(review): the loop body below has lost its indentation and looks cut off
# in this copy; re-indent it and restore any missing lines before running.
for n in range(30): | |
plt.subplot(6,5,n+1) | |
plt.imshow(val_image_batch[n]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get images and labels batch from validation dataset generator.
val_image_batch, val_label_batch = next(iter(valid_generator))
# Collapse each label row to a class id; assumes labels are per-class
# (e.g. one-hot) vectors — TODO confirm against the generator's class_mode.
true_label_ids = np.argmax(val_label_batch, axis=-1)
print("Validation batch shape:", val_image_batch.shape)
# >> Validation batch shape: (32, 224, 224, 3)
# Get predictions for images batch from the Keras model.
tf_model_predictions = flowers_model.predict(val_image_batch)
print("Prediction results shape:", tf_model_predictions.shape)
# >> Prediction results shape: (32, 5)