Skip to content

Instantly share code, notes, and snippets.

@zengqingfu1442
Forked from Shiina18/grpc_shm_client.py
Created December 15, 2023 08:05
Show Gist options
  • Save zengqingfu1442/402cf0ef6976f07d5e3f56f84be611ee to your computer and use it in GitHub Desktop.
import collections
import uuid
from typing import Dict
import numpy as np
import tritonclient
import tritonclient.grpc as grpcclient
import tritonclient.utils.shared_memory as shm
# Bundles everything needed to reference — and later tear down — one
# shared-memory region: the Infer* request objects, the local region handle,
# the region's registered name, and the per-array names and byte sizes.
ShmHandle = collections.namedtuple(
    'ShmHandle',
    ['shared_data', 'shm_handle', 'shm_name', 'names', 'byte_sizes'],
)
# https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#datatypes
# Maps numpy scalar types to Triton datatype strings. Keys are numpy scalar
# *types* (not dtype objects), so lookups must go through `array.dtype == key`
# comparisons rather than direct indexing (see TritonClient._get_handle).
DTYPE_np2triton = {
    np.int64: 'INT64',
    np.int32: 'INT32',
    np.int16: 'INT16',
    np.int8: 'INT8',
    np.float32: 'FP32',
    np.float16: 'FP16',
    # BUG FIX: `np.bool` was a deprecated alias for the builtin `bool` and was
    # removed in NumPy 1.24, which made this module fail at import time.
    # `np.bool_` is the actual numpy boolean scalar type.
    np.bool_: 'BOOL',
}
class TritonClient:
    """Triton inference client that exchanges tensors via system shared memory.

    CPU tensors and gRPC only. System shared memory requires the client and
    the Triton server to run on the same host (they must share /dev/shm).
    """

    def __init__(self, port: int):
        # Same-host assumption: system shared memory cannot cross machines.
        self.client = grpcclient.InferenceServerClient(url=f'localhost:{port}')

    def _get_handle(self, d, is_input):
        """Create and register one shared-memory region covering all arrays in `d`.

        All arrays are packed back-to-back into a single region; for inputs the
        array contents are copied in, for outputs only the space is reserved.

        Parameters
        ----------
        d : Dict[str, np.ndarray]
            {tensor_name: array}. For outputs the array is only used for its
            shape and byte size.
        is_input : bool
            True -> build ``InferInput`` objects (Triton dtypes resolved);
            False -> build ``InferRequestedOutput`` objects (no dtype needed).

        Returns
        -------
        ShmHandle

        Raises
        ------
        ValueError
            If an input array's dtype is not in ``DTYPE_np2triton``.
        """
        [*names], [*data] = zip(*d.items())
        byte_sizes = [array.size * array.itemsize for array in data]
        dtypes = []
        if is_input:
            for array in data:
                # Direct DTYPE_np2triton[array.dtype] would miss: the dict keys
                # are numpy scalar types while `array.dtype` is a dtype object,
                # so we rely on np.dtype's `==` comparison against types.
                for np_dtype, triton_dtype in DTYPE_np2triton.items():
                    if array.dtype == np_dtype:
                        dtypes.append(triton_dtype)
                        break
                else:
                    raise ValueError(f'Unsupported datatype {array.dtype}')
        # Unique region name so concurrent clients on the same host cannot
        # collide in the server's shared-memory namespace.
        triton_shm_name = uuid.uuid1().hex
        shm_key = '/' + triton_shm_name
        total_bytes = sum(byte_sizes)
        shm_handle = shm.create_shared_memory_region(
            triton_shm_name, shm_key, byte_size=total_bytes
        )
        try:
            self.client.register_system_shared_memory(
                triton_shm_name, shm_key, byte_size=total_bytes
            )
        except Exception:
            # Don't leak the local region if the server rejects registration.
            shm.destroy_shared_memory_region(shm_handle)
            raise
        offset = 0
        for i, array in enumerate(data):
            shm.set_shared_memory_region(shm_handle, [array], offset=offset)
            offset += byte_sizes[i]
        shared_data = []
        offset = 0
        for i, array in enumerate(data):
            if dtypes:
                shared_data.append(
                    grpcclient.InferInput(names[i], list(array.shape), dtypes[i])
                )
            else:
                shared_data.append(grpcclient.InferRequestedOutput(names[i]))
            shared_data[-1].set_shared_memory(
                triton_shm_name, byte_sizes[i], offset=offset
            )
            offset += byte_sizes[i]
        return ShmHandle(
            shared_data=shared_data, shm_handle=shm_handle, shm_name=triton_shm_name,
            names=names, byte_sizes=byte_sizes
        )

    def infer(
        self,
        model_name: str,
        input_dict: Dict[str, np.ndarray],
        output_dict: Dict[str, np.ndarray],
        model_version: str = '',
    ) -> Dict[str, np.ndarray]:
        """Run one inference round-trip through shared memory.

        Parameters
        ----------
        model_name
        input_dict
            {input_name: array}
        output_dict
            {output_name: array}
            The array is only used to get the shape and the memory size, so
            a random array is enough (np.empty).
        model_version
            The default value is an empty string which means then the server
            will choose a version based on the model and internal policy.

        Returns
        -------
        final_output_dict
            {output_name: array}

        Raises
        ------
        RuntimeError
            If the server response is missing a requested output.
        """
        # BUG FIX: cleanup was previously only reached on success, so any
        # exception (failed registration, failed inference, missing output)
        # leaked registered shared-memory regions on both client and server.
        # All teardown now happens in `finally`.
        input_handle = None
        output_handle = None
        try:
            input_handle = self._get_handle(input_dict, is_input=True)
            output_handle = self._get_handle(output_dict, is_input=False)
            results = self.client.infer(
                model_name=model_name,
                model_version=model_version,
                inputs=input_handle.shared_data,
                outputs=output_handle.shared_data
            )
            final_output_dict = {}
            offset = 0
            for i, name in enumerate(output_handle.names):
                value_pb = results.get_output(name)
                # BUG FIX: was `assert`, which is stripped under `python -O`;
                # a real exception keeps the check in optimized runs.
                if value_pb is None:
                    raise RuntimeError(f'Missing output `{name}`')
                value = shm.get_contents_as_numpy(
                    output_handle.shm_handle,
                    tritonclient.utils.triton_to_np_dtype(value_pb.datatype),
                    value_pb.shape, offset=offset
                )
                offset += output_handle.byte_sizes[i]
                # `value` is a view into the region destroyed below; without a
                # copy the caller reads freed memory (segfault, exit code 139).
                final_output_dict[name] = np.copy(value)
            return final_output_dict
        finally:
            for handle in (input_handle, output_handle):
                if handle is not None:
                    self.client.unregister_system_shared_memory(handle.shm_name)
                    shm.destroy_shared_memory_region(handle.shm_handle)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment