This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Resize the TFLite interpreter's first input and first output tensors to
# batch shapes, allocate the tensors, then re-read the tensor details and
# print the input shape.
#
# NOTE: `tflite_interpreter`, `input_details` and `output_details` must
# already exist when this excerpt runs: the two resize calls read the
# details before they are reassigned further down.
tflite_interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
tflite_interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))

# Allocation happens only after both resize calls have been issued.
tflite_interpreter.allocate_tensors()

# Re-query the details after allocation instead of reusing the earlier ones.
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("\n== Output details ==")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build a TFLite interpreter from the model file at TFLITE_MODEL and print
# the shape and dtype of its first input and the shape of its first output.
tflite_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)

input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

# Create the output directory for the converted models. `os.makedirs` with
# exist_ok=True replaces the IPython-only `!mkdir` shell escape: it is valid
# plain Python and does not fail when the directory already exists.
os.makedirs("tflite_models", exist_ok=True)

TFLITE_MODEL = "tflite_models/flowers.tflite"
TFLITE_QUANT_MODEL = "tflite_models/flowers_quant.tflite"

# Wrap the Keras model call in a tf.function so a concrete function can be
# obtained from it.
run_model = tf.function(lambda x: flowers_model(x))

# Get the concrete function, using the shape and dtype of the first input of
# `model` as the input signature.
# NOTE(review): the wrapped callable is `flowers_model` while the signature
# is read from `model` — presumably both describe the same network; confirm.
# NOTE: the excerpt was cut off inside this call; the closing parenthesis is
# restored so that the statement parses.
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Turn the per-class scores in `tf_model_predictions` into one class id per
# image (argmax over the last axis), then look those ids up in
# `dataset_labels`.
predicted_ids = np.argmax(tf_model_predictions, axis=-1)
predicted_labels = dataset_labels[predicted_ids]

# Draw the first 30 images of the validation batch on a 6x5 subplot grid.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    # Subplot positions are 1-based, hence n + 1.
    plt.subplot(6, 5, n + 1)
    plt.imshow(val_image_batch[n])
    # NOTE: the excerpt ends here; any further per-image statements of the
    # loop body are not visible.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pull one (images, labels) batch from the validation generator and derive
# the true class id of every image via argmax over the last label axis.
val_image_batch, val_label_batch = next(iter(valid_generator))
true_label_ids = np.argmax(val_label_batch, axis=-1)

print("Validation batch shape:", val_image_batch.shape)
# >> Validation batch shape: (32, 224, 224, 3)

# Run the model on the whole image batch in one call.
tf_model_predictions = flowers_model.predict(val_image_batch)

print("Prediction results shape:", tf_model_predictions.shape)
# >> Prediction results shape: (32, 5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reload the model stored at FLOWERS_SAVED_MODEL. `hub.KerasLayer` is passed
# through `custom_objects` under the name 'KerasLayer' so that it can be
# resolved while loading.
# NOTE(review): `tf.keras.experimental.load_from_saved_model` is a deprecated
# experimental API; `tf.keras.models.load_model` is the documented
# replacement — confirm against the TensorFlow version in use, and switch it
# together with the matching export call.
flowers_model = tf.keras.experimental.load_from_saved_model(
    FLOWERS_SAVED_MODEL,
    custom_objects={'KerasLayer': hub.KerasLayer},
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Directory the model is exported to.
FLOWERS_SAVED_MODEL = "saved_models/flowers"

# Write `model` to FLOWERS_SAVED_MODEL.
# NOTE(review): `tf.keras.experimental.export_saved_model` is a deprecated
# experimental API; `model.save(...)` is the documented replacement — confirm
# against the TensorFlow version in use, and switch it together with the
# matching load call.
tf.keras.experimental.export_saved_model(model, FLOWERS_SAVED_MODEL)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Number of batches needed to cover every sample once, rounding the last
# partial batch up. `np.ceil` returns a float, so the result is converted to
# `int` because step counts are whole numbers.
steps_per_epoch = int(
    np.ceil(train_generator.samples / train_generator.batch_size)
)
val_steps_per_epoch = int(
    np.ceil(valid_generator.samples / valid_generator.batch_size)
)

# Train for 10 epochs with validation after each epoch; only the `.history`
# attribute of the object returned by `fit` is kept.
hist = model.fit(
    train_generator,
    epochs=10,
    verbose=1,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=val_steps_per_epoch,
).history
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get the Flowers dataset: fetch the archive under the name 'flower_photos'
# and unpack it (untar=True).
data_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True,
)

# Settings for the training and validation data generators.
IMAGE_SHAPE = (224, 224)
TRAINING_DATA_DIR = str(data_root)

# Keyword arguments collected for the data generator: scale values by 1/255
# and reserve a 0.20 fraction as the validation split.
# NOTE(review): the generator construction itself is not visible in this
# excerpt — confirm where these kwargs are consumed.
datagen_kwargs = dict(rescale=1./255, validation_split=.20)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Classifier assembled as a three-layer Sequential model:
#   1. a TF Hub MobileNet V2 feature-vector layer, declared with
#      output_shape=[1280] and frozen via trainable=False,
#   2. dropout with rate 0.4,
#   3. a softmax Dense layer with one unit per class reported by
#      `train_generator.num_classes`.
model = tf.keras.Sequential([
    hub.KerasLayer(
        "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
        output_shape=[1280],
        trainable=False,
    ),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(train_generator.num_classes, activation='softmax'),
])