Skip to content

Instantly share code, notes, and snippets.

View frogermcs's full-sized avatar
🤓
Did you do good today?

Mirosław Stanek frogermcs

🤓
Did you do good today?
View GitHub Profile
#Download and unpack MobileNet v2 model from https://www.tensorflow.org/lite/guide/hosted_models
!curl -LO http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz
!curl -LO https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
!tar -xvzf mobilenet_v2_1.0_224.tgz
# List of unpacked files
# >> mobilenet_v2_1.0_224.ckpt.data-00000-of-00001
# >> mobilenet_v2_1.0_224.ckpt.index
# >> mobilenet_v2_1.0_224.ckpt.meta
# >> mobilenet_v2_1.0_224_eval.pbtxt
!pip install tfcoreml
import tensorflow as tf
import tfcoreml
# tfcoreml does not support TensorFlow 2.0 yet, so run this with a 1.x install.
# Report the runtime up front so a mismatched environment is obvious.
print(f"TensorFlow version {tf.__version__}")
print(f"Eager mode:  {tf.executing_eagerly()}")
print(f"Is GPU available:  {tf.test.is_gpu_available()}")
# Concatenation of argmax and max value for each row
def max_values_only(data):
    """Reduce a 2-D score table to one ``[argmax, max]`` pair per row.

    Args:
        data: 2-D array-like of scores, one row per sample.

    Returns:
        ndarray of shape ``(n_rows, 2)``. Column 0 is the index of the largest
        value in each row, and column 1 is that value. ``np.concatenate``
        gives both columns one common dtype, so for float scores the index
        column comes back as floats (e.g. ``1.0``).
    """
    argmax_col = np.argmax(data, axis=1).reshape(-1, 1)
    max_col = np.max(data, axis=1).reshape(-1, 1)
    return np.concatenate([argmax_col, max_col], axis=1)
# Build simplified prediction tables: one [argmax, max] row per image for
# the TF model, the TFLite model and the quantized TFLite model.
(
    tf_model_pred_simplified,
    tflite_model_pred_simplified,
    tflite_q_model_pred_simplified,
) = map(
    max_values_only,
    (tf_model_predictions, tflite_model_predictions, tflite_q_model_predictions),
)
# Place the three models' prediction tables next to each other. `keys`
# becomes the outer column level and names the model each column came from.
all_models_dataframe = pd.concat(
    [tf_pred_dataframe, tflite_pred_dataframe, tflite_q_pred_dataframe],
    axis='columns',
    keys=['TF Model', 'TFLite', 'TFLite quantized'],
)
# Swap the column levels to have a side-by-side comparison. After the swap,
# the outer level is the original column name and the inner level is the
# model. Selecting by tflite_pred_dataframe's columns orders the outer level
# like that table, so each column's model variants sit next to each other.
# (Assumes all three tables share the same column names.)
all_models_dataframe = (
    all_models_dataframe
    .swaplevel(axis='columns')
    [tflite_pred_dataframe.columns]
)
# Classify a whole image batch with the TFLite interpreter.
# NOTE(review): set_tensor needs the input already resized to the batch shape
# and tensors allocated; that happens in the resize cell.
tflite_interpreter.set_tensor(input_details[0]['index'], val_image_batch)  # load the batch
tflite_interpreter.invoke()  # run the model
tflite_model_predictions = tflite_interpreter.get_tensor(output_details[0]['index'])  # read scores back
print(f"Prediction results shape: {tflite_model_predictions.shape}")
# >> Prediction results shape: (32, 5)
# Resize the interpreter's input to (32, 224, 224, 3) so a whole validation
# batch can be classified with a single invoke().
# Only the input is resized: resize_tensor_input is documented for input
# tensors, and allocate_tensors() recomputes the output shapes from it.
tflite_interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
tflite_interpreter.allocate_tensors()
# Re-read the tensor details; the previously fetched ones predate the resize.
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
# Load the converted model from TFLITE_MODEL and show the input/output
# tensor metadata the interpreter reports.
tflite_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()
print("== Input details ==")
print(f"shape: {input_details[0]['shape']}")
print(f"type: {input_details[0]['dtype']}")
print("\n== Output details ==")
print(f"shape: {output_details[0]['shape']}")
!mkdir "tflite_models"
TFLITE_MODEL = "tflite_models/flowers.tflite"
TFLITE_QUANT_MODEL = "tflite_models/flowers_quant.tflite"
# Get the concrete function from the Keras model.
run_model = tf.function(lambda x : flowers_model(x))
# Save the concrete function.
concrete_func = run_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
# Get predictions for each image
predicted_ids = np.argmax(tf_model_predictions, axis=-1)
predicted_labels = dataset_labels[predicted_ids]
# Plot the images of the batch in a 6x5 grid. Each image is titled with its
# predicted label: green when it matches the true label, red otherwise.
# The grid is capped by the batch size so a short final batch does not
# index past its end.
n_rows, n_cols = 6, 5
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(min(n_rows * n_cols, len(val_image_batch))):
    plt.subplot(n_rows, n_cols, n + 1)
    plt.imshow(val_image_batch[n])
    color = "green" if predicted_ids[n] == true_label_ids[n] else "red"
    plt.title(predicted_labels[n].title(), color=color)
    plt.axis('off')
# Take one batch from the validation generator and derive each image's true
# class index. The argmax over the last axis implies per-class label vectors
# (presumably one-hot; confirm against the generator setup).
val_image_batch, val_label_batch = next(iter(valid_generator))
true_label_ids = np.argmax(val_label_batch, axis=-1)
print(f"Validation batch shape: {val_image_batch.shape}")
# >> Validation batch shape: (32, 224, 224, 3)
# Score the whole batch with the Keras model in one call.
tf_model_predictions = flowers_model.predict(val_image_batch)
print(f"Prediction results shape: {tf_model_predictions.shape}")
# >> Prediction results shape: (32, 5)