@rmccorm4
Last active October 12, 2022 06:28
#!/usr/bin/env python3
import argparse
from typing import Tuple, List

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # Implicitly initializes the CUDA driver and creates a context
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def is_fixed(shape: Tuple[int]):
    return not is_dynamic(shape)


def is_dynamic(shape: Tuple[int]):
    return any(dim is None or dim < 0 for dim in shape)


def setup_binding_shapes(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    host_inputs: List[np.ndarray],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
):
    # Explicitly set the dynamic input shapes, so the dynamic output
    # shapes can be computed internally
    for host_input, binding_index in zip(host_inputs, input_binding_idxs):
        context.set_binding_shape(binding_index, host_input.shape)

    assert context.all_binding_shapes_specified

    host_outputs = []
    device_outputs = []
    for binding_index in output_binding_idxs:
        output_shape = context.get_binding_shape(binding_index)
        # Allocate buffers to hold output results after copying back to host
        buffer = np.empty(output_shape, dtype=np.float32)
        host_outputs.append(buffer)
        # Allocate output buffers on device
        device_outputs.append(cuda.mem_alloc(buffer.nbytes))
    return host_outputs, device_outputs


def get_binding_idxs(engine: trt.ICudaEngine, profile_index: int):
    # Calculate start/end binding indices for current context's profile
    num_bindings_per_profile = engine.num_bindings // engine.num_optimization_profiles
    start_binding = profile_index * num_bindings_per_profile
    end_binding = start_binding + num_bindings_per_profile

    # Separate input and output binding indices for convenience
    input_binding_idxs = []
    output_binding_idxs = []
    for binding_index in range(start_binding, end_binding):
        if engine.binding_is_input(binding_index):
            input_binding_idxs.append(binding_index)
        else:
            output_binding_idxs.append(binding_index)
    return input_binding_idxs, output_binding_idxs


def load_engine(filename: str):
    # Load serialized engine file into memory
    with open(filename, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())


def get_random_inputs(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    input_binding_idxs: List[int],
):
    # Input data for inference
    host_inputs = []
    for binding_index in input_binding_idxs:
        # If input shape is fixed, we'll just use it
        input_shape = context.get_binding_shape(binding_index)
        # If input shape is dynamic, we'll arbitrarily select one of the
        # min/opt/max shapes from our optimization profile
        if is_dynamic(input_shape):
            profile_index = context.active_optimization_profile
            profile_shapes = engine.get_profile_shape(profile_index, binding_index)
            # 0=min, 1=opt, 2=max, or choose any shape, (min <= shape <= max)
            input_shape = profile_shapes[1]
        host_inputs.append(np.random.random(input_shape).astype(np.float32))
    return host_inputs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--engine", required=True, type=str, help="Path to TensorRT engine file."
    )
    args = parser.parse_args()

    # Load a serialized engine into memory
    engine = load_engine(args.engine)

    # Create a context; this can be re-used
    context = engine.create_execution_context()
    # Profile 0 (first profile) is used by default
    context.active_optimization_profile = 0

    # These binding_idxs can change if either the context or the
    # active_optimization_profile are changed
    input_binding_idxs, output_binding_idxs = get_binding_idxs(
        engine, context.active_optimization_profile
    )

    # Generate random inputs based on profile shapes
    host_inputs = get_random_inputs(engine, context, input_binding_idxs)
    print("Input Shapes: {}".format([inp.shape for inp in host_inputs]))

    # Allocate device memory for inputs. This can be easily re-used if the
    # input shapes don't change
    device_inputs = [cuda.mem_alloc(h_input.nbytes) for h_input in host_inputs]
    # Copy host inputs to device; this needs to be done for each new input
    for h_input, d_input in zip(host_inputs, device_inputs):
        cuda.memcpy_htod(d_input, h_input)

    # This needs to be called every time your input shapes change.
    # If your inputs are always the same shape (same batch size, etc.),
    # then you will only need to call this once
    host_outputs, device_outputs = setup_binding_shapes(
        engine, context, host_inputs, input_binding_idxs, output_binding_idxs,
    )
    print("Output Shapes: {}".format([out.shape for out in host_outputs]))

    # Bindings are a list of device pointers for inputs and outputs
    bindings = device_inputs + device_outputs

    # Inference
    context.execute_v2(bindings)

    # Copy outputs back to host to view results
    for h_output, d_output in zip(host_outputs, device_outputs):
        cuda.memcpy_dtoh(h_output, d_output)

    # View outputs
    print(host_outputs)

    # Cleanup (can also use context managers instead)
    del context
    del engine


if __name__ == "__main__":
    main()
@rmccorm4 (Author) commented Jun 29, 2020

Hi @TerryBryant,

Please refer to this code block: https://gist.github.com/rmccorm4/dabccb1f31dbdcf1019a4df431067e52#file-dynamic_shape_inference-py-L28-L33

When all input binding shapes for an execution context have been specified (context.all_binding_shapes_specified==True), TensorRT should calculate the output binding shapes for that context automatically under the hood.

You can verify this by checking the context output binding shapes before and after setting the input binding shapes.
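For example, a minimal sketch of that check, assuming a context with one dynamic input at binding index 0 and one output at binding index 1 (the indices and shapes here are illustrative, not from the gist):

# Before the input shape is set, dynamic output dims are reported as -1
print(context.get_binding_shape(1))             # e.g. (-1, 1000)
# Set a concrete input shape that falls within the active profile's range
context.set_binding_shape(0, (4, 3, 224, 224))  # illustrative shape
# TensorRT has now resolved the output shape internally
print(context.get_binding_shape(1))             # e.g. (4, 1000)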

@TerryBryant

I see, thanks a lot!

@TerryBryant

Hi, I'm trying to run this script with multiple threads. I want to load the engine once and create a separate context for each thread, because each thread has a different input size, so the binding shapes keep changing. But when I write it this way, the following errors occur:

[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
[TensorRT] WARNING: Could not set default profile 0 for execution context. Profile index must be set explicitly.
[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.

I searched for this problem and only found some C++ docs, which say:

If the associated CUDA engine has dynamic inputs, this method must be called at least once with a unique profileIndex before calling execute or enqueue (i.e. the profile index may not be in use by another execution context that has not been destroyed yet). For the first execution context that is created for an engine, setOptimizationProfile(0) is called implicitly.

But I still don't know how to write the multithreaded script. Could you help me? Thanks in advance!

@rmccorm4 (Author) commented Jul 9, 2020

Hi @TerryBryant,

[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.

You're likely using the same profile index on each thread, which isn't currently allowed.


At engine building time (single thread), you'll need to create at least as many optimization profiles as the number of threads you expect to be running simultaneously.

At runtime, for each thread you will likely need to do something like:

  1. Create a new execution context
  2. Assign the execution context an optimization profile that's not currently in use by another thread

For example, let's say you want to run 4 threads.

# --- Build time --- #

# Create 4 opt profiles ...
profile0 = builder.create_optimization_profile()
profile0.set_shape(...)
builder_config.add_optimization_profile(profile0) # profile_index=0

profile1 = builder.create_optimization_profile()
profile1.set_shape(...)
builder_config.add_optimization_profile(profile1) # profile_index=1

profile2 = builder.create_optimization_profile()
profile2.set_shape(...)
builder_config.add_optimization_profile(profile2) # profile_index=2

profile3 = builder.create_optimization_profile()
profile3.set_shape(...)
builder_config.add_optimization_profile(profile3) # profile_index=3
...
engine = builder.build_engine(network, builder_config)
# --- Inference time --- #

# Create an execution context with a unique optimization profile for each thread

# thread 0
context0 = engine.create_execution_context()
context0.active_optimization_profile = 0

# thread 1
context1 = engine.create_execution_context()
context1.active_optimization_profile = 1

# thread 2
context2 = engine.create_execution_context()
context2.active_optimization_profile = 2

# thread 3
context3 = engine.create_execution_context()
context3.active_optimization_profile = 3
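
To make the per-thread flow concrete, here's a hedged sketch, assuming engine is already deserialized and reusing the get_binding_idxs helper from the gist above; the worker function, thread count, and elided buffer setup are illustrative:

import threading

def worker(profile_index):
    # One context per thread, pinned to a profile no other thread uses
    context = engine.create_execution_context()
    context.active_optimization_profile = profile_index
    # Binding indices are offset per profile, so recompute them for this context
    input_idxs, output_idxs = get_binding_idxs(engine, profile_index)
    # ... then set binding shapes, allocate buffers, and call
    # context.execute_v2(bindings) exactly as in the gist above

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()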

@TerryBryant

Hi, @rmccorm4 ,
Thank you for your sample code, I think I get it. But I still expect a better solution, because the input size of all my data are in the same range, which means I only need one kind of profile and all data can share. Also in inference time, it's not so convenient to assign different profile for each thread, because the thread may run in random.

Waiting for more instructions. Thanks.

@rmccorm4 (Author) commented Jul 10, 2020

Hi @TerryBryant,

You can define several profiles covering the same range of shapes. I agree it would be nice if the same profile could be re-used by multiple threads simultaneously, but I don't believe that's currently possible.
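
For example, a minimal build-time sketch, assuming builder and builder_config as in the earlier example, a single input named "input", and an illustrative min/opt/max range (the input name, shapes, and thread count are placeholders):

# Create several profiles that all cover the same shape range,
# one per thread expected to run concurrently
for _ in range(4):
    profile = builder.create_optimization_profile()
    # Arguments: (input_name, min_shape, opt_shape, max_shape)
    profile.set_shape("input", (1, 3, 224, 224), (4, 3, 224, 224), (8, 3, 224, 224))
    builder_config.add_optimization_profile(profile)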

@TerryBryant

OK, I got it. Thank you.

@TerryBryant commented Jul 10, 2020

I've tried this solution, and two problems occurred:

  1. The serialized engine file becomes very large, because I added 10 profiles.
  2. The following log appears. I can still run the inference, so I don't know whether it's just a warning or whether something actually goes wrong:

[TensorRT] WARNING: Total space of persistent layer space is 524160 on host and 3773253120 on device
[TensorRT] ERROR: ../rtSafe/safeRuntime.cpp (25) - Cuda Error in allocate: 2 (out of memory)
[TensorRT] ERROR: FAILED_ALLOCATION: std::exception

So I think this is only a temporary solution. I hope you and the official TensorRT team can take this multithreading problem into consideration.
