Created July 23, 2019 18:09
-
-
Save tlkh/f48c16254c5d31f8328ddfd4a24e244b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf

# Switch TF 1.x into eager execution; TensorFlow expects this at program
# startup, before other TF objects are created.
tf.enable_eager_execution()

import tensorflow.keras as keras
# NOTE(review): this module import appears unused at module level in this script.
from tensorflow.keras.preprocessing import image
import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet50 import (
    ResNet50,
    decode_predictions,
    preprocess_input,
)
from tensorflow.keras.layers import Input

# Session configuration that asks the graph optimizer for XLA JIT (level ON_1).
# NOTE(review): eager execution is enabled above; confirm whether the ops Keras
# runs actually go through this Session, otherwise the JIT setting is inert.
config = tf.ConfigProto()
optimizer_opts = config.graph_options.optimizer_options
optimizer_opts.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)
tf.keras.backend.set_session(sess)

# Default Keras float type for everything built below is half precision.
tf.keras.backend.set_floatx('float16')

# ImageNet-pretrained ResNet50 on a 224x224 RGB input, wrapped in a Model that
# exposes the classifier output directly.
input_layer = Input(shape=(224, 224, 3))
base_model = ResNet50(input_tensor=input_layer, weights='imagenet')
model = keras.models.Model(inputs=input_layer, outputs=base_model.output)
def format_example(batch):
    """Build a model-ready example from a raw dataset record.

    The record's "image" is cast to float32 and resized to 224x224; no pixel
    rescaling or normalization is applied here. The "label" entry is passed
    through unchanged. A new dict is returned; *batch* is not modified.
    """
    resized = tf.image.resize(tf.cast(batch["image"], tf.float32), (224, 224))
    return {"image": resized, "label": batch["label"]}
def show_images(images, cols, titles):
    """Plot a batch of images as a grid of small thumbnails with titles.

    Args:
        images: Batch of images, assumed to be on a 0-255 scale (they are
            divided by 255 before display). Anything ``np.asarray`` accepts
            works, e.g. an eager ``tf.Tensor`` batch or a NumPy array.
        cols: Number of grid columns. The row count is rounded up, so a batch
            whose size is not a multiple of ``cols`` still fits.
        titles: One title per image; titles longer than 10 characters are
            shortened to their first 7 characters followed by "...".
    """
    n_images = len(images)
    # Ceiling division: floor division under-counted rows, so add_subplot
    # raised for images past the last full row (or for every image when
    # n_images < cols, which produced a 0-row grid).
    n_rows = (n_images + cols - 1) // cols
    pixels = np.asarray(images) / 255
    # Downscale to 40x40 before plotting to keep drawing cheap.
    thumbs = [cv2.resize(img, (40, 40), interpolation=cv2.INTER_LINEAR) for img in pixels]
    fig = plt.figure(figsize=(14, 6), dpi=48)
    for idx, (thumb, title) in enumerate(zip(thumbs, titles)):
        ax = fig.add_subplot(n_rows, cols, idx + 1)
        ax.imshow(thumb)
        ax.axis("off")
        if len(title) > 10:
            title = title[:7] + "..."
        ax.set_title(title)
    plt.show()
# Input pipeline over the cats_vs_dogs training split: shuffle with a
# 4096-example buffer, convert each record with format_example (in parallel),
# group into batches of 40, and prefetch ahead of the consumer.
dataset = (
    tfds.load(name="cats_vs_dogs", split=tfds.Split.TRAIN)
    .shuffle(4096)
    .map(format_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(40)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# clear_output was previously referenced as IPython.display.clear_output
# without IPython ever being imported (NameError). Import it explicitly and
# skip clearing when IPython is not available (e.g. running as a plain script).
try:
    from IPython.display import clear_output
except ImportError:
    clear_output = None

# Classify and display the first 10 batches, reporting timing per batch.
# Each iteration handles one batch (one displayed grid), so the printed
# "FPS" figures are batches per second, not images per second.
for features in dataset.take(10):
    image_batch = features["image"]
    # The ImageNet ResNet50 weights expect inputs transformed by
    # preprocess_input. np.array copies the batch, so the in-place
    # preprocessing does not alter the 0-255 pixels used for display.
    # Done before start_time so only the model call is timed as inference.
    model_input = preprocess_input(np.array(image_batch))
    start_time = time.perf_counter()
    preds = model.predict_on_batch(model_input)
    inf_time = time.perf_counter()
    # decode_predictions yields, per image, a list of (class_id, name, score);
    # keep the name of the single top-1 entry.
    labels = [top[0][1] for top in decode_predictions(preds, top=1)]
    if clear_output is not None:
        clear_output(wait=True)
    show_images(image_batch, cols=10, titles=labels)
    end_time = time.perf_counter()
    print("\tInference:\t", round(1/(inf_time-start_time), 2), "FPS")
    print("\tPlotting:\t", round(1/(end_time-inf_time), 2), "FPS")
    print("\tOverall:\t", round(1/(end_time-start_time), 2), "FPS")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.