@RomanGirin
Created January 17, 2020 10:59
import sys, os
# Make the parent directory importable *before* importing common, which
# typically lives there when running from the TensorRT samples tree.
sys.path.insert(1, os.path.join(sys.path[0], ".."))

import keras
import numpy as np
import tensorrt as trt
# common.py is the same helper module used by the standard TensorRT samples
# (typically found in /usr/src/tensorrt/samples/python/common.py).
import common
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit
# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
class ModelData(object):
    # To create the .uff, follow the instructions at
    # https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#end_to_end_tensorflow_mnist
    MODEL_FILE = "lenet5.uff"
    ENGINE_FILE = "lenet5.engine"
    INPUT_NAME = "input_1"
    INPUT_SHAPE = (1, 28, 28)
    OUTPUT_NAME = "dense_1/Softmax"
def get_engine(model_file, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        # For more information on TRT basics, refer to the introductory samples.
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
            builder.max_workspace_size = common.GiB(1)
            # Parse the UFF network.
            parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
            parser.register_output(ModelData.OUTPUT_NAME)
            parser.parse(model_file, network)
            print('Completed parsing of UFF file')
            # Build, serialize and return an engine.
            print('Building an engine from file {}; this may take a while...'.format(model_file))
            engine = builder.build_cuda_engine(network)
            print('Completed creating engine')
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building a new one.
        print('Reading engine from file {}'.format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
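
# Note: a serialized engine is specific to the TensorRT version and GPU it was
# built on; delete the cached .engine file to force a rebuild after upgrading.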
def prepare_input():
    # Import the data.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Uncomment this if you want to take a look at the picture we are going to process:
    # import matplotlib.pyplot as plt
    # plt.imshow(x_test[0], cmap='gray', interpolation='none')
    # plt.title("Class {}".format(y_test[0]))
    # plt.show()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    # Reshape the data.
    NUM_TRAIN = 60000
    NUM_TEST = 10000
    x_train = np.reshape(x_train, (NUM_TRAIN, 28, 28, 1))
    x_test = np.reshape(x_test, (NUM_TEST, 28, 28, 1))
    # For this quick demo, just take the first test instance and run it through the engine.
    image = np.expand_dims(x_test[0], axis=0)
    # Convert the image to row-major order, also known as "C order":
    image = np.array(image, dtype=np.float32, order='C')
    return image
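
# Note: the engine's input was registered as CHW (1, 28, 28), while the array
# above is NHWC (1, 28, 28, 1). With a single channel the flattened bytes are
# identical, so the host-to-device copy in common.do_inference still lines up.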
def main():
    model_path = os.path.join(os.path.dirname(__file__), "models")
    model_file = os.path.join(model_path, ModelData.MODEL_FILE)
    engine_file_path = os.path.join(model_path, ModelData.ENGINE_FILE)
    image = prepare_input()
    with get_engine(model_file, engine_file_path) as engine, engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        # Do inference.
        print('Running inference on image...')
        # Set host input to the image. The common.do_inference function will copy the input to the GPU before executing.
        inputs[0].host = image
        trt_outputs = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
        print(trt_outputs)

if __name__ == '__main__':
    main()
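
# To turn the raw softmax scores into a predicted digit, a quick sketch
# (trt_outputs[0] holds the flat 10-class output; not part of the original gist):
#
#   pred = np.argmax(trt_outputs[0])
#   print("Predicted digit: {}".format(pred))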