Skip to content

Instantly share code, notes, and snippets.

View frogermcs's full-sized avatar
🤓
Did you do good today?

Mirosław Stanek frogermcs

🤓
Did you do good today?
View GitHub Profile
# Resize the interpreter's tensors for batched inference.
# `input_details` / `output_details` are read here before this fragment
# assigns them, so they must already exist from an earlier step.
# First input tensor -> 32 items of 224 x 224 x 3
# (presumably the validation batch size — confirm).
tflite_interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
# First output tensor -> 32 rows of 5 values (presumably one score per class — confirm).
tflite_interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
# Allocate tensors after the resize calls above.
tflite_interpreter.allocate_tensors()
# Re-read the tensor details so the printed shapes come from the interpreter
# state after resizing and allocation.
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("\n== Output details ==")
# Load the converted model file into a TensorFlow Lite interpreter.
tflite_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)

# Tensor metadata reported by the interpreter (lists; index 0 is the first tensor).
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

# Report shape and dtype of the first input tensor.
first_input = input_details[0]
print("== Input details ==")
print("shape:", first_input['shape'])
print("type:", first_input['dtype'])

# Report the shape of the first output tensor.
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
import os

# Create the output directory for the converted TensorFlow Lite models.
# os.makedirs(..., exist_ok=True) replaces the notebook-only `!mkdir` shell
# escape: it is plain Python, works outside IPython, and does not fail when
# the directory already exists (e.g. when the cell is re-run).
os.makedirs("tflite_models", exist_ok=True)

# Target paths for the regular and the quantized TFLite model files.
TFLITE_MODEL = "tflite_models/flowers.tflite"
TFLITE_QUANT_MODEL = "tflite_models/flowers_quant.tflite"
# Wrap the Keras model call in a tf.function so a concrete function can be
# obtained from it. The lambda forwards its single argument to `flowers_model`.
run_model = tf.function(lambda x : flowers_model(x))
# Get the concrete function for one input signature: a TensorSpec built from
# the shape and dtype of `model.inputs[0]`.
# NOTE(review): the lambda calls `flowers_model` while the spec is read from
# `model` — confirm both describe the same input.
# NOTE(review): this call is cut off in this view (no closing parenthesis is
# visible); the remainder of the statement is not shown here.
concrete_func = run_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
# Get predictions for each image:
# argmax over the last axis picks the highest-scoring class index per row.
predicted_ids = np.argmax(tf_model_predictions, axis=-1)
# Map class indices to labels by indexing `dataset_labels` with the id array
# (presumably a NumPy array of label names — confirm).
predicted_labels = dataset_labels[predicted_ids]

# Plot the images of the batch on a 6 x 5 grid (the first 30 images).
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    # Subplot positions are 1-based, hence n + 1.
    plt.subplot(6, 5, n + 1)
    plt.imshow(val_image_batch[n])
# Pull one (images, labels) batch from the validation dataset generator.
validation_batches = iter(valid_generator)
val_image_batch, val_label_batch = next(validation_batches)

# Reduce each label row to the index of its largest entry.
true_label_ids = np.argmax(val_label_batch, axis=-1)

print("Validation batch shape:", val_image_batch.shape)
# >> Validation batch shape: (32, 224, 224, 3)

# Run the Keras model on the whole image batch.
tf_model_predictions = flowers_model.predict(val_image_batch)

print("Prediction results shape:", tf_model_predictions.shape)
# >> Prediction results shape: (32, 5)
# Reload the exported Keras model from FLOWERS_SAVED_MODEL.
# `custom_objects` maps the 'KerasLayer' name to hub.KerasLayer for the loader.
hub_custom_objects = {'KerasLayer': hub.KerasLayer}
flowers_model = tf.keras.experimental.load_from_saved_model(
    FLOWERS_SAVED_MODEL,
    custom_objects=hub_custom_objects,
)
# Path the trained Keras model is exported to.
FLOWERS_SAVED_MODEL = "saved_models/flowers"
# Export `model` to that path.
# NOTE(review): tf.keras.experimental.export_saved_model appears to be
# deprecated in later TF 2.x releases — confirm against the installed version.
# Whatever API writes the export must stay paired with the loader that reads
# FLOWERS_SAVED_MODEL back.
tf.keras.experimental.export_saved_model(model, FLOWERS_SAVED_MODEL)
# Number of batches needed to cover every sample once (the last batch may be
# partial, hence ceil). np.ceil returns a NumPy float, while the step-count
# arguments of fit() are integers, so the result is converted explicitly.
steps_per_epoch = int(np.ceil(train_generator.samples / train_generator.batch_size))
val_steps_per_epoch = int(np.ceil(valid_generator.samples / valid_generator.batch_size))

# Train for 10 epochs, validating on `valid_generator`; only the `history`
# attribute of the object returned by fit() is kept in `hist`.
hist = model.fit(
    train_generator,
    epochs=10,
    verbose=1,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=val_steps_per_epoch,
).history
# Fetch the Flowers archive and unpack it (untar=True); `data_root` is the
# value returned by get_file.
data_root = tf.keras.utils.get_file(
    fname='flower_photos',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True,
)

# Settings shared by the training and validation data generators.
IMAGE_SHAPE = (224, 224)
TRAINING_DATA_DIR = str(data_root)
# Rescale factor 1/255 and a 20% validation split.
datagen_kwargs = {'rescale': 1./255, 'validation_split': .20}
# Non-trainable feature-vector layer loaded from TensorFlow Hub
# (declared output size: 1280).
feature_extractor = hub.KerasLayer(
    "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
    output_shape=[1280],
    trainable=False,
)
# Dropout with rate 0.4 between the feature layer and the output layer.
dropout = tf.keras.layers.Dropout(0.4)
# Softmax output layer with one unit per class reported by the training generator.
classifier = tf.keras.layers.Dense(train_generator.num_classes, activation='softmax')

# Stack the three layers in order.
model = tf.keras.Sequential([feature_extractor, dropout, classifier])