# GitHub Gist tlkh/f48c16254c5d31f8328ddfd4a24e244b (created July 23, 2019).
# Saved from the GitHub web page; page navigation text removed.
import time
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing import image
import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.layers import Input
# Config to turn on JIT compilation: request XLA JIT (level ON_1) through a
# TF1-style session config (tf.ConfigProto / tf.Session are TF1 APIs).
# NOTE(review): `sess` is not referenced anywhere else in this script and is
# never registered with Keras (e.g. keras.backend.set_session), so this JIT
# setting may not apply to `model` below — confirm.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)
# ResNet50 with ImageNet weights, built on an explicit 224x224x3 input tensor.
input_layer = Input(shape=(224,224,3,))
base_model = ResNet50(input_tensor=input_layer, weights='imagenet')
# Wrap as a plain Model from the same input to ResNet50's output
# (presumably equivalent to `base_model` itself — kept as written).
model = keras.models.Model(inputs=input_layer, outputs=base_model.output)
def format_example(batch, size=(224, 224)):
    """Convert one dataset element into float32 pixels at the model's input size.

    Args:
        batch: dict with "image" and "label" entries (as yielded by the
            tfds "cats_vs_dogs" pipeline in this script).
        size: (height, width) to resize the image to; the default matches
            the 224x224 input layer the ResNet50 model is built on.

    Returns:
        dict with the resized float32 "image" and the unchanged "label".
        No normalization is applied here; pixel values keep their original
        range, so the same tensor can be used for plotting.
    """
    pixels = tf.cast(batch["image"], tf.float32)
    pixels = tf.image.resize(pixels, size)
    return {"image": pixels, "label": batch["label"]}
def show_images(images, cols, titles, *, show=True):
    """Plot a batch of images as a grid of thumbnails, one short title each.

    Args:
        images: batch of images with pixel values in [0, 255] — a tf.Tensor
            from the input pipeline, or anything `np.asarray` accepts.
            Assumed to be (N, H, W, C); TODO confirm for other callers.
        cols: number of grid columns.
        titles: one title per image. Titles longer than 10 characters are
            cut to their first 7 characters plus "...".
        show: if True (the default), call `plt.show()` after drawing.

    Returns:
        The matplotlib Figure the grid was drawn into.
    """
    # np.asarray works for eager tensors and NumPy arrays alike; scale to
    # [0, 1] floats, which is the range imshow expects for float data.
    pixels = np.asarray(images) / 255
    n_images = len(pixels)
    # Ceil division so a partial last row still gets a grid row (floor
    # division made add_subplot fail when n_images % cols != 0).
    n_rows = max(1, -(-n_images // cols))
    # Downscale to 40x40 thumbnails to keep the figure cheap to render.
    thumbs = [cv2.resize(img, (40, 40), interpolation=cv2.INTER_LINEAR) for img in pixels]
    fig = plt.figure(figsize=(14, 6), dpi=48)
    for n, (thumb, title) in enumerate(zip(thumbs, titles)):
        ax = fig.add_subplot(n_rows, cols, n + 1)
        if len(title) > 10:
            title = title[:7] + "..."
        ax.imshow(thumb)
        ax.set_title(title)
        ax.axis("off")
    if show:
        plt.show()
    return fig
# Construct a tf.data.Dataset of cats_vs_dogs training examples.
dataset = tfds.load(name="cats_vs_dogs", split=tfds.Split.TRAIN)
# Build your input pipeline: shuffle, resize every image to the model's input
# size (required before batching, since the raw images differ in size),
# batch, and prefetch.
dataset = dataset.shuffle(4096)
dataset = dataset.map(format_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(40)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
# NOTE(review): iterating a Dataset directly (and .numpy() inside
# show_images) requires eager execution; under TF1 make sure eager mode is
# enabled — confirm.
for features in dataset.take(10):
    image_batch, label_batch = features["image"], features["label"]
    # The ImageNet ResNet50 weights expect `preprocess_input`. It can modify a
    # NumPy array in place, so it runs on a copy, leaving `image_batch` as raw
    # pixels for plotting.
    model_input = preprocess_input(np.array(image_batch))
    start_time = time.perf_counter()
    preds = model.predict_on_batch(model_input)
    inf_time = time.perf_counter()
    # decode_predictions yields, per image, a list of (class_id, name, score);
    # with top=1 keep the single predicted class name.
    labels = [top[0][1] for top in decode_predictions(preds, top=1)]
    show_images(image_batch, cols=10, titles=labels)
    end_time = time.perf_counter()
    # Rates below are per batch (one plotted grid of up to 40 images).
    print("\tInference:\t", round(1/(inf_time-start_time), 2), "FPS")
    print("\tPlotting:\t", round(1/(end_time-inf_time), 2), "FPS")
    print("\tOverall:\t", round(1/(end_time-start_time), 2), "FPS")
# (End of gist; GitHub page footer text removed.)