@nhubbard · Created May 29, 2019 19:11
Benchmarking code for the VTK model: two scripts that measure per-frame inference FPS of a frozen TensorFlow MobileNet-SSD graph on a live camera feed, first with the graph as-is and then after TF-TRT optimization. Both scripts take the path to the frozen graph as the first command-line argument and the camera device index as the second. The unoptimized baseline comes first:
import numpy as np
import tensorflow as tf
import cv2 as cv
from time import time
import sys, os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

cap = cv.VideoCapture(int(sys.argv[2]))
lat = []

# Read the graph.
with open(sys.argv[1], 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    # Read and preprocess an image.
    while cv.waitKey(1) & 0xFF != ord("q"):
        status, img = cap.read()
        if not status:
            # Stop if no frame could be read from the camera.
            break
        rows = img.shape[0]
        cols = img.shape[1]
        inp = cv.resize(img, (300, 300))
        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

        start_time = time()
        # Run the model
        out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                        sess.graph.get_tensor_by_name('detection_scores:0'),
                        sess.graph.get_tensor_by_name('detection_boxes:0'),
                        sess.graph.get_tensor_by_name('detection_classes:0')],
                       feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
        end_time = time()

        # Record instantaneous FPS (1 / inference latency) for up to 1000 frames.
        if len(lat) <= 1000:
            lat.append(1 / (end_time - start_time))
        else:
            break

        # Visualize detected bounding boxes.
        num_detections = int(out[0][0])
        for i in range(num_detections):
            classId = int(out[3][0][i])
            score = float(out[1][0][i])
            bbox = [float(v) for v in out[2][0][i]]
            if score > 0.3:
                x = bbox[1] * cols
                y = bbox[0] * rows
                right = bbox[3] * cols
                bottom = bbox[2] * rows
                cv.rectangle(img, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=2)
        cv.imshow('TensorFlow MobileNet-SSD', img)

cap.release()

# Write the collected FPS samples (minus the warm-up frame) as one CSV row.
filename = sys.argv[1].split(".")[0]
with open(filename + "_unopt_fps.csv", "w") as output:
    output.write(",".join(str(e) for e in lat[1:]))
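The CSV written above is a single comma-separated row of per-frame FPS samples (the warm-up sample is dropped via lat[1:]). The sketch below is one minimal, hypothetical way to summarize such a file; the summarize_fps helper and the example path are assumptions for illustration, not part of the gist.

import statistics

def summarize_fps(path):
    # Parse one comma-separated row of FPS samples and report mean and median.
    with open(path) as f:
        samples = [float(v) for v in f.read().split(",") if v.strip()]
    return statistics.mean(samples), statistics.median(samples)

# Example usage (the path is an assumption; it depends on the graph filename
# passed as sys.argv[1] to the benchmark script):
# mean_fps, median_fps = summarize_fps("frozen_inference_graph_unopt_fps.csv")
# print("mean: {m:.1f} FPS, median: {md:.1f} FPS".format(m=mean_fps, md=median_fps))

The second script repeats the same capture-and-inference loop, but first converts the frozen graph with TF-TRT (tensorflow.contrib.tensorrt) and writes its samples to <graph>_trt_fps.csv: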
import numpy as np
import tensorflow as tf
import cv2 as cv
import tensorflow.contrib.tensorrt as trt
from time import time
import sys
import os
from termcolor import colored

print(colored("Opening handle on camera /dev/video{s}".format(s=sys.argv[2]), "green", attrs=["bold"]))
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

cap = cv.VideoCapture(int(sys.argv[2]))
lat = []

print(colored("Live-optimizing graph with TensorRT", "green", attrs=["bold"]))
# Read the graph.
with open(sys.argv[1], 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Convert the frozen graph with TF-TRT before importing it.
trt_graph = trt.create_inference_graph(
    input_graph_def=graph_def,
    outputs=["num_detections:0", "detection_scores:0", "detection_boxes:0", "detection_classes:0"],
    precision_mode="int8"
)

print(colored("Launching TensorFlow session", "green", attrs=["bold"]))
with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(trt_graph, name='')

    # Read and preprocess an image.
    while cv.waitKey(1) & 0xFF != ord("q"):
        status, img = cap.read()
        if not status:
            # Stop if no frame could be read from the camera.
            break
        rows = img.shape[0]
        cols = img.shape[1]
        inp = cv.resize(img, (300, 300))
        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

        start_time = time()
        # Run the model
        out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                        sess.graph.get_tensor_by_name('detection_scores:0'),
                        sess.graph.get_tensor_by_name('detection_boxes:0'),
                        sess.graph.get_tensor_by_name('detection_classes:0')],
                       feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
        end_time = time()

        # Record instantaneous FPS (1 / inference latency) for up to 1000 frames.
        if len(lat) <= 1000:
            lat.append(1 / (end_time - start_time))
        else:
            break

        # Visualize detected bounding boxes.
        num_detections = int(out[0][0])
        for i in range(num_detections):
            classId = int(out[3][0][i])
            score = float(out[1][0][i])
            bbox = [float(v) for v in out[2][0][i]]
            if score > 0.7:
                x = bbox[1] * cols
                y = bbox[0] * rows
                right = bbox[3] * cols
                bottom = bbox[2] * rows
                cv.rectangle(img, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=2)
        cv.putText(img, "Try me!", (10, 100), cv.FONT_HERSHEY_SIMPLEX, 2, (125, 255, 51), 4, cv.LINE_AA)
        cv.putText(img, "Put a ball or hatch into the frame.", (10, 150), cv.FONT_HERSHEY_SIMPLEX, 1, (125, 255, 51), 2, cv.LINE_AA)
        cv.imshow('TensorFlow MobileNet-SSD', img)

print(colored("Closing camera session", "green", attrs=["bold"]))
cap.release()

filename = sys.argv[1].split(".")[0]
print(colored("Writing initial statistics", "green", attrs=["bold"]))
with open(filename + "_trt_fps.csv", "w") as output:
    output.write(",".join(str(e) for e in lat))
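With both CSV files written, the relative speedup from the TensorRT conversion can be computed directly. A minimal comparison sketch, reusing the hypothetical summarize_fps helper above; the file names are assumptions and depend on the graph filename passed on the command line:

# Hypothetical comparison of the two benchmark runs (paths are assumptions).
unopt_mean, _ = summarize_fps("frozen_inference_graph_unopt_fps.csv")
trt_mean, _ = summarize_fps("frozen_inference_graph_trt_fps.csv")
print("unoptimized: {u:.1f} FPS".format(u=unopt_mean))
print("TF-TRT int8: {t:.1f} FPS".format(t=trt_mean))
print("speedup:     {s:.2f}x".format(s=trt_mean / unopt_mean))

Note that the TRT script writes all of lat, including the warm-up sample, while the unoptimized script drops it, so the TRT mean includes the first, slower frame.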