Created
May 29, 2019 19:11
-
-
Save nhubbard/32b873a43451f390b921d32de6421896 to your computer and use it in GitHub Desktop.
Benchmarking code for VTK model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark an unoptimized TensorFlow object-detection graph on a live camera.

Usage: python <script> GRAPH.pb CAMERA_INDEX

Runs the frozen graph GRAPH.pb on frames from OpenCV camera CAMERA_INDEX,
draws detections scoring above SCORE_THRESHOLD, and writes the per-frame
inference rate (1 / seconds spent in ``sess.run``) to
"<GRAPH path without extension>_unopt_fps.csv" as one comma-separated line.
The first WARMUP_RUNS samples are discarded. Press "q" in the preview window
to stop early.
"""
import os
import sys
from time import perf_counter, time  # noqa: F401

# TensorFlow's native logging may pick up TF_CPP_MIN_LOG_LEVEL as soon as the
# library loads, so it must be set *before* `import tensorflow` to be reliable.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import cv2 as cv  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402

INPUT_TENSOR_NAME = "image_tensor:0"
# Fetch order matters: main() unpacks the sess.run result in this order.
OUTPUT_TENSOR_NAMES = (
    "num_detections:0",
    "detection_scores:0",
    "detection_boxes:0",
    "detection_classes:0",
)
INPUT_SIZE = (300, 300)  # (width, height) as expected by cv.resize
SCORE_THRESHOLD = 0.3  # boxes scoring at or below this are not drawn
BOX_COLOR = (125, 255, 51)  # in the frame's channel order (BGR from the camera)
WARMUP_RUNS = 1  # leading samples dropped from the CSV
MEASURED_RUNS = 1000  # samples written to the CSV
WINDOW_NAME = "TensorFlow MobileNet-SSD"


def _draw_detections(img, num_detections, scores, boxes, threshold):
    """Draw every detection scoring above *threshold* onto *img* in place.

    *boxes* rows are (ymin, xmin, ymax, xmax) fractions of the frame size and
    are scaled to *img*'s pixel dimensions before drawing.
    """
    rows, cols = img.shape[:2]
    for i in range(int(num_detections[0])):
        if float(scores[0][i]) <= threshold:
            continue
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        cv.rectangle(
            img,
            (int(xmin * cols), int(ymin * rows)),
            (int(xmax * cols), int(ymax * rows)),
            BOX_COLOR,
            thickness=2,
        )


def main():
    """Run the benchmark described in the module docstring."""
    if len(sys.argv) < 3:
        sys.exit("usage: {} GRAPH.pb CAMERA_INDEX".format(sys.argv[0]))
    graph_path = sys.argv[1]
    camera_index = int(sys.argv[2])

    cap = cv.VideoCapture(camera_index)
    fps = []  # inference rate (1 / seconds) of each timed sess.run call
    try:
        if not cap.isOpened():
            sys.exit("could not open camera {}".format(camera_index))

        # Read the frozen graph and import it into a dedicated Graph, so the
        # session and the tensor lookups below unambiguously refer to it.
        with open(graph_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def, name="")

        # Resolve tensors once so the timed region covers only sess.run.
        image_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
        output_tensors = [graph.get_tensor_by_name(n) for n in OUTPUT_TENSOR_NAMES]

        with tf.Session(graph=graph) as sess:
            while len(fps) < WARMUP_RUNS + MEASURED_RUNS:
                if cv.waitKey(1) & 0xFF == ord("q"):
                    break
                ok, img = cap.read()
                if not ok:
                    # A failed read yields no image; stop instead of crashing.
                    print("Camera returned no frame; stopping early.", file=sys.stderr)
                    break

                # Preprocess outside the timed region.
                rgb = cv.resize(img, INPUT_SIZE)[:, :, [2, 1, 0]]  # BGR -> RGB
                batch = rgb[np.newaxis, ...]  # add the batch axis

                start = perf_counter()  # monotonic, high-resolution clock
                num_detections, scores, boxes, _classes = sess.run(
                    output_tensors, feed_dict={image_tensor: batch}
                )
                fps.append(1.0 / (perf_counter() - start))

                _draw_detections(img, num_detections, scores, boxes, SCORE_THRESHOLD)
                cv.imshow(WINDOW_NAME, img)
    finally:
        cap.release()
        cv.destroyAllWindows()

    # splitext keeps directories and dots in the stem intact; the original
    # split(".")[0] turned e.g. "./model.pb" into "".
    output_path = os.path.splitext(graph_path)[0] + "_unopt_fps.csv"
    with open(output_path, "w") as output:
        output.write(",".join(str(e) for e in fps[WARMUP_RUNS:]))


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark a TensorRT-optimized TensorFlow object-detection graph on a live camera.

Usage: python <script> GRAPH.pb CAMERA_INDEX

Converts the frozen graph GRAPH.pb with TF-TRT (precision_mode="int8") and runs
it on frames from OpenCV camera CAMERA_INDEX. Detections scoring above
SCORE_THRESHOLD are drawn with an on-screen prompt. The per-frame inference rate
(1 / seconds spent in ``sess.run``) is written to
"<GRAPH path without extension>_trt_fps.csv" as one comma-separated line.
The first WARMUP_RUNS samples are discarded. Press "q" in the preview window
to stop early.
"""
import os
import sys
from time import perf_counter, time  # noqa: F401

# TensorFlow's native logging may pick up TF_CPP_MIN_LOG_LEVEL as soon as the
# library loads, so it must be set *before* `import tensorflow` to be reliable.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import cv2 as cv  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import tensorflow.contrib.tensorrt as trt  # noqa: E402
from termcolor import colored  # noqa: E402

INPUT_TENSOR_NAME = "image_tensor:0"
# Fetch order matters: main() unpacks the sess.run result in this order.
OUTPUT_TENSOR_NAMES = (
    "num_detections:0",
    "detection_scores:0",
    "detection_boxes:0",
    "detection_classes:0",
)
INPUT_SIZE = (300, 300)  # (width, height) as expected by cv.resize
SCORE_THRESHOLD = 0.7  # boxes scoring at or below this are not drawn
BOX_COLOR = (125, 255, 51)  # in the frame's channel order (BGR from the camera)
# Dropped from the CSV, matching the unoptimized benchmark. The first TF-TRT
# run can include engine construction, which would skew the statistics.
WARMUP_RUNS = 1
MEASURED_RUNS = 1000  # samples written to the CSV
WINDOW_NAME = "TensorFlow MobileNet-SSD"


def _status(message):
    """Print *message* as a bold green progress line."""
    print(colored(message, "green", attrs=["bold"]))


def _draw_detections(img, num_detections, scores, boxes, threshold):
    """Draw every detection scoring above *threshold* onto *img* in place.

    *boxes* rows are (ymin, xmin, ymax, xmax) fractions of the frame size and
    are scaled to *img*'s pixel dimensions before drawing.
    """
    rows, cols = img.shape[:2]
    for i in range(int(num_detections[0])):
        if float(scores[0][i]) <= threshold:
            continue
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        cv.rectangle(
            img,
            (int(xmin * cols), int(ymin * rows)),
            (int(xmax * cols), int(ymax * rows)),
            BOX_COLOR,
            thickness=2,
        )


def main():
    """Run the benchmark described in the module docstring."""
    if len(sys.argv) < 3:
        sys.exit("usage: {} GRAPH.pb CAMERA_INDEX".format(sys.argv[0]))
    graph_path = sys.argv[1]
    camera_arg = sys.argv[2]

    _status("Opening handle on camera /dev/video{s}".format(s=camera_arg))
    cap = cv.VideoCapture(int(camera_arg))
    fps = []  # inference rate (1 / seconds) of each timed sess.run call
    try:
        if not cap.isOpened():
            sys.exit("could not open camera {}".format(camera_arg))

        _status("Live-optimizing graph with TensorRT")
        with open(graph_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # NOTE(review): in TF 1.x contrib TF-TRT, INT8 precision with the
        # default calibration setting returns a *calibration* graph. That graph
        # is normally fed representative data and then converted with
        # trt.calib_graph_to_infer_graph() before serving. Confirm these
        # timings measure the final INT8 engine and not calibration.
        trt_graph = trt.create_inference_graph(
            input_graph_def=graph_def,
            outputs=list(OUTPUT_TENSOR_NAMES),
            precision_mode="int8",
        )
        # Import into a dedicated Graph so the session and the tensor lookups
        # below unambiguously refer to it.
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(trt_graph, name="")

        # Resolve tensors once so the timed region covers only sess.run.
        image_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
        output_tensors = [graph.get_tensor_by_name(n) for n in OUTPUT_TENSOR_NAMES]

        _status("Launching TensorFlow session")
        with tf.Session(graph=graph) as sess:
            while len(fps) < WARMUP_RUNS + MEASURED_RUNS:
                if cv.waitKey(1) & 0xFF == ord("q"):
                    break
                ok, img = cap.read()
                if not ok:
                    # A failed read yields no image; stop instead of crashing.
                    print("Camera returned no frame; stopping early.", file=sys.stderr)
                    break

                # Preprocess outside the timed region.
                rgb = cv.resize(img, INPUT_SIZE)[:, :, [2, 1, 0]]  # BGR -> RGB
                batch = rgb[np.newaxis, ...]  # add the batch axis

                start = perf_counter()  # monotonic, high-resolution clock
                num_detections, scores, boxes, _classes = sess.run(
                    output_tensors, feed_dict={image_tensor: batch}
                )
                fps.append(1.0 / (perf_counter() - start))

                _draw_detections(img, num_detections, scores, boxes, SCORE_THRESHOLD)
                cv.putText(img, "Try me!", (10, 100), cv.FONT_HERSHEY_SIMPLEX,
                           2, BOX_COLOR, 4, cv.LINE_AA)
                cv.putText(img, "Put a ball or hatch into the frame.", (10, 150),
                           cv.FONT_HERSHEY_SIMPLEX, 1, BOX_COLOR, 2, cv.LINE_AA)
                cv.imshow(WINDOW_NAME, img)
    finally:
        _status("Closing camera session")
        cap.release()
        cv.destroyAllWindows()

    # splitext keeps directories and dots in the stem intact; the original
    # split(".")[0] turned e.g. "./model.pb" into "".
    output_path = os.path.splitext(graph_path)[0] + "_trt_fps.csv"
    _status("Writing initial statistics")
    with open(output_path, "w") as output:
        output.write(",".join(str(e) for e in fps[WARMUP_RUNS:]))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment