Created
January 17, 2020 10:59
-
-
Save RomanGirin/e9792c3a34a78a6d6e0729041cbc81b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

import keras
import numpy as np
import tensorrt as trt

# `common` is the helper module shipped with the standard TensorRT samples
# (typically /usr/src/tensorrt/samples/python/common.py).
import common

sys.path.insert(1, os.path.join(sys.path[0], ".."))

# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit

# Raise the severity to suppress messages, or lower it to display more.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
class ModelData(object):
    """Constants describing the LeNet-5 UFF model and its I/O tensors.

    To create the .uff file, follow the instructions at
    https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#end_to_end_tensorflow_mnist
    """

    MODEL_FILE = "lenet5.uff"
    ENGINE_FILE = "lenet5.engine"
    # Input tensor name/shape as registered with the UFF parser (CHW: 1x28x28).
    INPUT_NAME = "input_1"
    INPUT_SHAPE = (1, 28, 28)
    # Name of the softmax output tensor in the exported graph.
    OUTPUT_NAME = "dense_1/Softmax"
def get_engine(model_file, engine_file_path=""):
    """Load a serialized engine if one exists, otherwise build a new TensorRT engine.

    Args:
        model_file: Path to the .uff model to parse.
        engine_file_path: Optional path to a cached serialized engine. When it
            exists, the engine is deserialized from it; when the build path is
            taken and this is non-empty, the freshly built engine is saved here.

    Returns:
        A TensorRT ICudaEngine.

    Raises:
        RuntimeError: If the UFF file cannot be parsed or the engine build fails.
    """
    def build_engine():
        # For more information on TRT basics, refer to the introductory samples.
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
            builder.max_workspace_size = common.GiB(1)
            # Parse the Uff network into the TensorRT network definition.
            parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
            parser.register_output(ModelData.OUTPUT_NAME)
            # Original ignored the parse result and built from an empty network on failure.
            if not parser.parse(model_file, network):
                raise RuntimeError("Failed to parse UFF file {}".format(model_file))
            print('Completed parsing of UFF file')
            print('Building an engine from file {}; this may take a while...'.format(model_file))
            engine = builder.build_cuda_engine(network)
            # build_cuda_engine returns None on failure; calling .serialize() on it
            # would raise an opaque AttributeError.
            if engine is None:
                raise RuntimeError("Failed to build TensorRT engine from {}".format(model_file))
            print("Completed creating Engine")
            # Only cache when a path was supplied: with the default "" the original
            # executed open("", "wb") and crashed after a successful build.
            if engine_file_path:
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # A serialized engine exists — deserialize it instead of rebuilding.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
def prepare_input(index=0):
    """Load MNIST and return one normalized test image ready for the engine.

    Args:
        index: Which test-set image to return (default 0, matching the
            original behavior).

    Returns:
        A C-contiguous float32 array of shape (1, 28, 28, 1) with pixel
        values scaled to [0, 1].
    """
    # Import the data (downloads/caches MNIST via keras).
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Uncomment this if you need to take a look at the picture we are going to process
    # import matplotlib.pyplot as plt
    # plt.imshow(x_test[index], cmap='gray', interpolation='none')
    # plt.title("Class {}".format(y_test[index]))
    # plt.show()
    # Scale pixel intensities from [0, 255] to [0, 1].
    x_test = x_test / 255.0
    # Reshape to NHWC without hard-coding the dataset size (the original
    # baked in NUM_TEST = 10000).
    x_test = np.reshape(x_test, (-1, 28, 28, 1))
    # Take the requested instance and add a batch dimension.
    image = np.expand_dims(x_test[index], axis=0)
    # TensorRT expects row-major ("C order") float32 host input.
    return np.ascontiguousarray(image, dtype=np.float32)
def main():
    """Build or load the LeNet-5 TensorRT engine and run inference on one MNIST image."""
    model_path = os.path.join(os.path.dirname(__file__), "models")
    model_file = os.path.join(model_path, ModelData.MODEL_FILE)
    engine_file_path = os.path.join(model_path, ModelData.ENGINE_FILE)
    # Renamed from `input`, which shadowed the builtin of the same name.
    image = prepare_input()
    with get_engine(model_file, engine_file_path) as engine, engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        # Do inference
        print('Running inference on image...')
        # Set host input to the image. The common.do_inference function will copy
        # the input to the GPU before executing.
        inputs[0].host = image
        trt_outputs = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
        print(trt_outputs)

if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment