Skip to content

Instantly share code, notes, and snippets.

@rmccorm4
Last active May 11, 2020 18:22
Show Gist options
  • Save rmccorm4/bb4a0d505ca2d355b39739de9497eee2 to your computer and use it in GitHub Desktop.
Save rmccorm4/bb4a0d505ca2d355b39739de9497eee2 to your computer and use it in GitHub Desktop.
def setup_binding_shapes(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    host_inputs: List[np.ndarray],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
):
    """Bind dynamic input shapes and allocate matching output buffers.

    Sets each input binding's shape on the execution context (required for
    dynamic-shape engines so TensorRT can resolve the output shapes), then
    allocates one host numpy buffer and one device buffer per output binding.

    Args:
        engine: The engine, used to query each output binding's dtype.
        context: Execution context whose binding shapes are set in place.
        host_inputs: Host-side input arrays; shapes are taken from these.
        input_binding_idxs: Binding indices corresponding to host_inputs.
        output_binding_idxs: Binding indices of the engine outputs.

    Returns:
        (host_outputs, device_outputs): lists of numpy arrays and device
        allocations, one pair per output binding index.

    Raises:
        RuntimeError: if not all binding shapes are specified after setting
            the input shapes.
    """
    # Explicitly set the dynamic input shapes, so the dynamic output
    # shapes can be computed internally.
    for host_input, binding_index in zip(host_inputs, input_binding_idxs):
        context.set_binding_shape(binding_index, host_input.shape)
    # Use an explicit raise instead of `assert` so the check survives `-O`.
    if not context.all_binding_shapes_specified:
        raise RuntimeError("Not all binding shapes are specified")

    host_outputs = []
    device_outputs = []
    for binding_index in output_binding_idxs:
        output_shape = context.get_binding_shape(binding_index)
        # Match the binding's declared dtype instead of hard-coding float32;
        # a hard-coded dtype mis-sizes the buffers for FP16/INT8/INT32
        # outputs. (This is also the only use of the `engine` parameter.)
        dtype = np.dtype(trt.nptype(engine.get_binding_dtype(binding_index)))
        # Allocate buffers to hold output results after copying back to host.
        buffer = np.empty(output_shape, dtype=dtype)
        host_outputs.append(buffer)
        # Allocate output buffers on device, sized from the host buffer.
        device_outputs.append(cuda.mem_alloc(buffer.nbytes))
    return host_outputs, device_outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment