Skip to content

Instantly share code, notes, and snippets.

@TheBojda
Created October 21, 2019 08:31
Show Gist options
  • Save TheBojda/4b5a0e7023d3e209321e7a459c208dd0 to your computer and use it in GitHub Desktop.
Save TheBojda/4b5a0e7023d3e209321e7a459c208dd0 to your computer and use it in GitHub Desktop.
TensorFlow image classification example
# TensorFlow image classification example
# based on https://www.tensorflow.org/tutorials/keras/classification
# model generation: https://gist.github.com/TheBojda/f297544cc4864b2b10c2aad965339c58
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
# Trained classifier produced by the model-generation gist referenced above.
model = load_model('my_model.h5')

# Both CIFAR-10 splits are loaded; only the test split is used in this script.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Rescale pixel values from [0, 255] into [0, 1] (presumably matching the
# preprocessing used when the model was trained -- confirm).
train_images = train_images / 255.0
test_images = test_images / 255.0

# Human-readable class names, indexed by the integer class label.
class_names = ('airplane automobile bird cat deer '
               'dog frog horse ship truck').split()
def plot_image(i, predictions_array, true_label, img):
    """Draw image ``i`` with its predicted class, confidence and true class.

    The x-axis label reads ``"<predicted> <score>% (<true>)"``. It is blue
    when the prediction matches the true label and red otherwise.

    Args:
        i: Index of the sample to draw within ``true_label`` and ``img``.
        predictions_array: Per-class scores for sample ``i`` only (one row, not
            the whole batch). The arg-max is taken as the predicted class, and
            its value is shown multiplied by 100 as a percentage.
            NOTE(review): this presumes the scores are probabilities, i.e. the
            saved model ends in softmax -- confirm.
        true_label: Integer labels for all samples. This accepts both a flat
            ``(N,)`` layout and the ``(N, 1)`` column layout that the original
            ``[i][0]`` indexing expected.
        img: Images for all samples; ``img[i]`` is passed to ``plt.imshow``.
    """
    # np.ravel turns both a scalar label and a length-1 row into a 1-D array,
    # so [0] yields the label for either layout.
    label = int(np.ravel(true_label[i])[0])
    image = img[i]

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    # cmap only applies to single-channel images; matplotlib ignores it for RGB(A).
    plt.imshow(image, cmap=plt.cm.binary)

    predicted_label = int(np.argmax(predictions_array))
    color = 'blue' if predicted_label == label else 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100 * np.max(predictions_array),
                                         class_names[label]),
               color=color)
def plot_value_array(i, predictions_array, true_label):
    """Draw a bar chart of the per-class scores for sample ``i``.

    Every bar starts grey. The predicted class (arg-max) is then coloured red
    and the true class blue. Blue is applied last, so a correct prediction
    shows a single blue bar.

    Args:
        i: Index of the sample within ``true_label``.
        predictions_array: Per-class scores for sample ``i`` only. One bar is
            drawn per entry, so the chart adapts to the number of classes.
        true_label: Integer labels for all samples, in either a flat ``(N,)``
            layout or an ``(N, 1)`` column layout.
    """
    # np.ravel handles both a scalar label and a length-1 row.
    label = int(np.ravel(true_label[i])[0])
    num_classes = len(predictions_array)

    plt.grid(False)
    plt.xticks(range(num_classes))
    plt.yticks([])
    bars = plt.bar(range(num_classes), predictions_array, color="#777777")
    # NOTE(review): the fixed [0, 1] range presumes the scores are probabilities
    # (softmax output) -- confirm against the saved model.
    plt.ylim([0, 1])

    predicted_label = int(np.argmax(predictions_array))
    bars[predicted_label].set_color('red')
    bars[label].set_color('blue')
# Classify one test image and show the image beside its per-class scores.
i = 0

# A one-element slice keeps the leading batch axis for model.predict;
# row 0 of the result holds the scores for that single image.
batch = test_images[i:i + 1]
predictions = model.predict(batch)
print(predictions)
sample_scores = predictions[0]

plt.figure(figsize=(6, 3))

# Left panel: the image, labelled with its predicted and true classes.
plt.subplot(121)
plot_image(i, sample_scores, test_labels, test_images)

# Right panel: a bar chart of the scores.
plt.subplot(122)
plot_value_array(i, sample_scores, test_labels)

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment