Skip to content

Instantly share code, notes, and snippets.

@rmccorm4
Last active May 10, 2020 22:30
Show Gist options
  • Save rmccorm4/438b3bcee8e3718c3b6557363332f569 to your computer and use it in GitHub Desktop.
Save rmccorm4/438b3bcee8e3718c3b6557363332f569 to your computer and use it in GitHub Desktop.
def is_dynamic(shape: Tuple[int, ...]) -> bool:
    """Return True if any dimension of ``shape`` is unknown.

    A dimension is treated as unknown (dynamic) when it is ``None`` or a
    negative number such as -1. Shapes of any rank are accepted, including
    the empty shape, which is never dynamic.

    Args:
        shape: Sequence of dimensions, e.g. a tuple or a TensorRT binding
            shape that iterates over ints.

    Returns:
        True if at least one dimension is ``None`` or negative, else False.
    """
    return any(dim is None or dim < 0 for dim in shape)
def get_random_inputs(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    input_binding_idxs: List[int],
) -> List[np.ndarray]:
    """Build one random host array per input binding.

    For each binding index, the shape comes from
    ``context.get_binding_shape``. If that shape has a dynamic dimension
    (see ``is_dynamic``), the "opt" shape of the context's active
    optimization profile is used instead.

    Each array is filled with ``np.random.random`` values in [0, 1). It is
    then cast to the NumPy equivalent of the binding's TensorRT dtype, so
    the host buffer's element type and byte size match the binding rather
    than always being float32.

    This function does not call ``context.set_binding_shape``. When a
    profile shape was chosen for a dynamic input, the caller presumably
    still needs to apply the resulting ``arr.shape`` to the context before
    running inference.

    Args:
        engine: Engine whose binding dtypes and optimization profiles are
            queried.
        context: Execution context providing binding shapes and the active
            optimization profile.
        input_binding_idxs: Binding indices of the inputs to generate data
            for.

    Returns:
        List of host arrays, in the same order as ``input_binding_idxs``.
    """
    # The active profile cannot change while this function runs, so read it once.
    profile_index = context.active_optimization_profile

    host_inputs = []
    for binding_index in input_binding_idxs:
        # If the input shape is fixed, use it as-is.
        input_shape = context.get_binding_shape(binding_index)
        # If the input shape is dynamic, pick one of the min/opt/max shapes
        # from the active optimization profile.
        if is_dynamic(input_shape):
            profile_shapes = engine.get_profile_shape(profile_index, binding_index)
            # 0=min, 1=opt, 2=max; any shape with min <= shape <= max is valid.
            input_shape = profile_shapes[1]

        # Match the binding's element type instead of hard-coding float32.
        # For integer bindings, the [0, 1) random values truncate to 0.
        dtype = trt.nptype(engine.get_binding_dtype(binding_index))
        host_inputs.append(np.random.random(input_shape).astype(dtype))
    return host_inputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment