Created
June 30, 2020 08:28
-
-
Save jaemin93/e8607ff4c86f8f9211dfb1ceafb527e3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_engine(onnx_file_path, engine_file_path=""):
    """Return a TensorRT engine, loading a serialized one when available.

    If ``engine_file_path`` names an existing file, the engine is
    deserialized from it. Otherwise a new engine is built from the ONNX
    model at ``onnx_file_path`` and, when ``engine_file_path`` is non-empty,
    the serialized engine is written there for later runs.

    Args:
        onnx_file_path: Path of the ONNX model to parse when building.
        engine_file_path: Path of the serialized engine to read or write.
            The default ``""`` disables both loading and saving.

    Returns:
        The engine, or ``None`` when ONNX parsing or the engine build fails.
    """

    def build_engine():
        """Parse the ONNX file and build an engine from it; None on failure."""
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = 1 << 28  # 256 MiB
            builder.max_batch_size = 1
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print('ERROR: Failed to parse the ONNX file')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            # Force the first network input to a fixed shape before building.
            network.get_input(0).shape = [1, 1, 28, 28]
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            if engine is None:
                # A failed build yields no engine; report it the same way as a
                # parse failure instead of crashing on engine.serialize().
                print('ERROR: Failed to build the engine')
                return None
            print("Completed creating Engine")
            # Only cache the engine when a path was given; the default ""
            # would make open() fail.
            if engine_file_path:
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
if __name__ == "__main__":
    # Model to parse and the file used to cache the serialized engine.
    onnx_file_path = 'test.onnx'
    engine_file_path = 'test.engine'
    with get_engine(onnx_file_path, engine_file_path) as engine:
        with engine.create_execution_context() as context:
            inputs, outputs, bindings, stream = allocate_buffers(engine)
            # Run inference repeatedly, feeding a fresh random input each time.
            for _ in tqdm(range(100000)):
                inputs[0].host = np.array(torch.randn(28, 28, 1))
                trt_outputs = do_inference_v2(
                    context,
                    bindings=bindings,
                    inputs=inputs,
                    outputs=outputs,
                    stream=stream,
                )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment