# Gist mypapit/e3b26787c95caf840e5c16a79327d443 — created May 5, 2020 20:16.
# Note: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; review it in an editor that
# reveals hidden Unicode characters.
# -*- coding: utf-8 -*-
"""transfer_learning_with_hub.ipynb
Automatically generated by Colaboratory.
Original file is located at
    https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning_with_hub.ipynb
##### Copyright 2018 The TensorFlow Authors.
"""
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""# Transfer learning with TensorFlow Hub
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning_with_hub.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning_with_hub.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/transfer_learning_with_hub.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>
[TensorFlow Hub](http://tensorflow.org/hub) is a way to share pretrained model components. See the [TensorFlow Module Hub](https://tfhub.dev/) for a searchable listing of pre-trained models. This tutorial demonstrates:
1. How to use TensorFlow Hub with `tf.keras`.
1. How to do image classification using TensorFlow Hub.
1. How to do simple transfer learning.
## Setup
"""
import matplotlib.pylab as plt | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
from tensorflow.keras import layers | |
import pandas as pd | |
import seaborn as sns | |
IMAGE_SHAPE = (224, 224) | |
import numpy as np | |
import PIL.Image as Image | |
import time | |
import os | |
"""## Simple transfer learning
Using TF Hub it is simple to retrain the top layer of the model to recognize the classes in our dataset.
### Dataset
For this example you will use the images in the local `training_images` directory, organised as one
sub-directory per class; 25% of the images are held out for validation:
"""
DATA_ROOT = "training_images"
# rescale maps pixel values from [0, 255] to [0, 1]; validation_split reserves 25% of
# the images for the 'validation' subset, the rest form the 'training' subset.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, validation_split=0.25)
image_data = image_generator.flow_from_directory(DATA_ROOT, target_size=IMAGE_SHAPE, subset='training')
validation_image = image_generator.flow_from_directory(DATA_ROOT, target_size=IMAGE_SHAPE, subset='validation')

# Take the first training batch to inspect its shapes. image_batch / label_batch are
# reused further down (feature extraction, prediction plot, reload check).
image_batch, label_batch = next(image_data)
print("Image batch shape: ", image_batch.shape)
print("Label batch shape: ", label_batch.shape)
# TF Hub image feature-vector module; "#@param" exposes it as an editable Colab form field.
feature_extractor_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_075_224/feature_vector/4" #@param {type:"string"}
"""Create the feature extractor."""
# Input is an IMAGE_SHAPE image with 3 colour channels, i.e. (224, 224, 3).
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_url,
    input_shape=IMAGE_SHAPE + (3,),
)
"""It returns a 1280-length vector for each image:"""
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
"""Freeze the variables in the feature extractor layer, so that the training only modifies the new classifier layer."""
feature_extractor_layer.trainable = False
"""### Attach a classification head
Now wrap the hub layer in a `tf.keras.Sequential` model, and add a new classification layer.
"""
# The hub layer already declares input_shape, so it defines the model input. (An
# InputLayer placed *after* it would sit on the feature vector, not on the images.)
model = tf.keras.Sequential([
    feature_extractor_layer,
    # One output unit per class directory found by flow_from_directory; softmax
    # makes the model output class probabilities.
    layers.Dense(image_data.num_classes, activation="softmax"),
])
# Batch dimension is left unspecified (None).
model.build((None,) + IMAGE_SHAPE + (3,))
model.summary()
predictions = model(image_batch)
# Printed rather than left as a bare expression, which only displays in a notebook.
print(predictions.shape)
"""### Train the model
Use compile to configure the training process:
"""
# The classification head ends in a softmax, so the model outputs probabilities, not
# logits. from_logits must therefore be False; with True the loss would apply softmax
# a second time to already-normalised outputs.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=['acc'])
"""Now use the `.fit` method to train the model.
To keep this example short, train for just 2 epochs. To visualize the training progress, use a custom callback to log the loss and accuracy of each batch individually, instead of the epoch average.
"""
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Keras callback that records loss and accuracy after every training batch.

    Attributes:
        batch_losses: per-batch ``logs['loss']`` values, in training order.
        batch_acc: per-batch ``logs['acc']`` values; the ``'acc'`` key exists because
            the model is compiled with ``metrics=['acc']``.
    """

    def __init__(self):
        # Initialise the Keras Callback base-class state before adding our own.
        super().__init__()
        self.batch_losses = []
        self.batch_acc = []

    def on_train_batch_end(self, batch, logs=None):
        """Append this batch's loss/accuracy, then reset the model's running metrics."""
        self.batch_losses.append(logs['loss'])
        self.batch_acc.append(logs['acc'])
        # Resetting after each batch makes the next batch's logs describe that batch
        # alone instead of the running epoch average.
        self.model.reset_metrics()
batch_size = image_data.batch_size
val_batch_size = validation_image.batch_size
# Keras expects integer step counts. Round up so the final, partial batch of each
# subset is used; flooring would skip it, and would yield 0 validation steps when the
# validation subset is smaller than one batch.
steps_per_epoch = int(np.ceil(image_data.samples / batch_size))
validation_steps = int(np.ceil(validation_image.samples / val_batch_size))
batch_stats_callback = CollectBatchStats()

# Measure wall-clock training time.
starttime = time.time()
# Model.fit accepts the directory iterators directly; fit_generator is deprecated.
history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=validation_image,
                    validation_steps=validation_steps,
                    callbacks=[batch_stats_callback], shuffle=True)
"""Now, after even just a few training iterations, we can already see that the model is making progress on the task."""
endtime = time.time()
print("[[[Time Elapsed: {0} minutes".format((endtime-starttime) // 60))
# Per-batch loss collected by the CollectBatchStats callback.
plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)

# Per-batch accuracy collected by the same callback.
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)

# Per-epoch training vs. validation accuracy. The legend is built after the lines are
# plotted; called earlier, matplotlib has no labelled lines and draws an empty legend.
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("epoch")
plt.ylim([0,1])
plt.plot(history.history['acc'], label='train')
plt.plot(history.history['val_acc'], label='valid')
plt.legend(loc='upper left')

# Per-epoch training vs. validation loss. The y-range matches the per-batch loss plot
# so that loss values above 1 are not clipped.
plt.figure()
plt.ylabel("Loss")
plt.xlabel("epoch")
plt.ylim([0,2])
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='valid')
plt.legend(loc='upper left')
"""### Check the predictions
To label the predictions, first get the ordered list of class names:
"""
# class_indices maps class-directory name -> integer label; sorting by the label makes
# class_names[i] the name of output unit i.
classes = sorted(image_data.class_indices.items(), key=lambda pair: pair[1])
class_names = np.array([key.title() for key, _ in classes])
print(class_names)
"""Run the image batch through the model and convert the indices to class names."""
predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
"""Plot the result"""
# label_batch is presumably one-hot (flow_from_directory default class mode) — argmax
# recovers the integer label to compare with the prediction.
label_id = np.argmax(label_batch, axis=-1)
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
# Fill at most the 6x5 grid; a batch can hold fewer than 30 images (e.g. a small
# dataset), and indexing past its end would raise IndexError.
num_shown = min(30, len(image_batch))
for n in range(num_shown):
    plt.subplot(6,5,n+1)
    plt.imshow(image_batch[n])
    color = "green" if predicted_id[n] == label_id[n] else "red"
    plt.title(predicted_label_batch[n].title(), color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")
"""## Export your model
Now that you've trained the model, export it as a saved model:
"""
# Timestamped export directory so repeated runs do not overwrite each other.
t = time.time()
export_path = "latest-{}".format(int(t))
model.save(export_path)
#model.save("kuih-{}.h5".format(int(t)))
print("Saved Model to :{}".format(export_path))
print("#####################")
print("Converting to TFLite ....")
print("########################")

# Path 1: convert directly from the exported SavedModel directory.
converter = tf.lite.TFLiteConverter.from_saved_model(export_path)
tflite_model = converter.convert()
# Context manager guarantees the file is flushed and closed.
with open("converted_model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

# Path 2: reload the model and trace it into a concrete function whose input spec
# (shape and dtype) is taken from the original model's first input.
cmodel = tf.keras.models.load_model(export_path)
run_model = tf.function(lambda x: cmodel(x))
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
# Alternative (not used): take the default serving signature and pin the batch size to 1:
#   concrete_func = cmodel.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
#   concrete_func.inputs[0].set_shape([1, 224, 224, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
with open("concrete_model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)
"""Now confirm that we can reload it, and it still gives the same results:"""
reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
# Largest element-wise difference between original and reloaded predictions; it
# should be ~0. Printed because a bare expression is only displayed inside a notebook.
max_abs_diff = np.abs(reloaded_result_batch - result_batch).max()
print("Max abs difference after reload: {}".format(max_abs_diff))
"""This saved model can be loaded for inference later, or converted to [TFLite](https://www.tensorflow.org/lite/convert/) or [TFjs](https://github.com/tensorflow/tfjs-converter)."""
# (GitHub gist page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment")