Skip to content

Instantly share code, notes, and snippets.

@Nanguage
Created January 19, 2024 15:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Nanguage/363c5f8d9dd7b4db6eb9fcb7615554ca to your computer and use it in GitHub Desktop.
Save Nanguage/363c5f8d9dd7b4db6eb9fcb7615554ca to your computer and use it in GitHub Desktop.
"""Test the triton server proxy."""
from PIL import Image
import msgpack
import numpy as np
import requests
import gzip
import json
def get_config(server_url, model_name):
    """Fetch a model's configuration from the triton-client proxy service.

    Args:
        server_url: Base URL of the proxy server, e.g. "http://127.0.0.1:9520".
        model_name: Name of the model whose configuration is requested.

    Returns:
        The JSON-decoded response body.

    Raises:
        Exception: If the server answers with a non-success HTTP status.
    """
    response = requests.get(
        f"{server_url}/public/services/triton-client/get_config",
        # Let requests URL-encode the model name instead of concatenating it
        # raw, so names containing '&', '#', spaces, etc. arrive intact.
        params={"model_name": model_name},
    )
    if not response.ok:
        # Fail loudly (same style as `execute`) instead of trying to JSON-parse
        # an error page or silently returning an error payload as a "config".
        raise Exception(
            f"Failed to get config for {model_name}: "
            f"HTTP {response.status_code} {response.reason}: {response.text}"
        )
    return json.loads(response.content)
def encode_data(inputs):
    """Recursively encode numpy data using the imjoy-rpc ndarray representation.

    numpy arrays and numpy scalars become dicts with the keys "_rtype",
    "_rvalue" (raw bytes), "_rshape" (the shape tuple) and "_rdtype" (the dtype
    name). Lists and tuples are walked element-wise and always come back as
    lists. Dicts are walked value-wise. Any other value is returned unchanged.
    """
    if isinstance(inputs, (np.ndarray, np.generic)):
        return dict(
            _rtype="ndarray",
            _rvalue=inputs.tobytes(),
            _rshape=inputs.shape,
            _rdtype=str(inputs.dtype),
        )
    if isinstance(inputs, (tuple, list)):
        return [encode_data(item) for item in inputs]
    if isinstance(inputs, dict):
        return {key: encode_data(value) for key, value in inputs.items()}
    return inputs
def decode_data(outputs):
    """Recursively turn imjoy-rpc encoded ndarrays back into numpy arrays.

    Conversion rules:
    - A dict whose "_rtype" is "ndarray" and whose "_rdtype" is not "object"
      is rebuilt with ``np.frombuffer`` and reshaped to "_rshape".
    - Any other dict is decoded value-wise. Object-dtype ndarray dicts fall
      into this case and are not converted.
    - Lists and tuples are decoded element-wise and always returned as lists.
    - Every other value is returned unchanged.

    Note: ``np.frombuffer`` does not copy. When "_rvalue" is an immutable
    ``bytes`` object, the resulting array is read-only, so copy it before any
    in-place modification.
    """
    if isinstance(outputs, (tuple, list)):
        return [decode_data(item) for item in outputs]
    if not isinstance(outputs, dict):
        return outputs
    is_array = outputs.get("_rtype") == "ndarray" and outputs["_rdtype"] != "object"
    if not is_array:
        return {key: decode_data(value) for key, value in outputs.items()}
    flat = np.frombuffer(outputs["_rvalue"], dtype=outputs["_rdtype"])
    return flat.reshape(outputs["_rshape"])
def execute(inputs, server_url, model_name, **kwargs):
    """
    Execute a model on the Triton server through the triton-client proxy.

    The supported kwargs are consistent with pyotritonclient:
    https://github.com/oeway/pyotritonclient/blob/bc655a20fabc4611bbf3c12fb15439c8fc8ee9f5/pyotritonclient/__init__.py#L40-L50

    Args:
        inputs: Model inputs. numpy arrays anywhere inside (including nested
            lists, tuples and dicts) are encoded with ``encode_data``.
        server_url: Base URL of the proxy server.
        model_name: Name of the model to run.
        **kwargs: Extra options sent to the server alongside the inputs.

    Returns:
        The msgpack-decoded response, with ndarray payloads converted back to
        numpy arrays by ``decode_data``.

    Raises:
        Exception: If the server answers with a non-success HTTP status.
    """
    # Represent the numpy arrays with the imjoy_rpc encoding.
    # See: https://github.com/imjoy-team/imjoy-rpc#data-type-representation
    payload = {
        **kwargs,
        "inputs": encode_data(inputs),
        "model_name": model_name,
    }
    # Serialize the arguments as msgpack, gzip them, and POST to the server.
    compressed_data = gzip.compress(msgpack.dumps(payload))
    response = requests.post(
        f"{server_url}/public/services/triton-client/execute",
        data=compressed_data,
        headers={
            "Content-Type": "application/msgpack",
            "Content-Encoding": "gzip",
        },
    )
    if not response.ok:
        # `response.reason` is almost always set (e.g. "Internal Server Error"),
        # so the old `reason or text` hid the server's actual error body.
        # Report the status, the reason and the body together.
        raise Exception(
            f"Failed to execute {model_name}: "
            f"HTTP {response.status_code} {response.reason}: {response.text}"
        )
    # Decode the results from the response, then convert the encoded ndarray
    # objects into numpy arrays.
    results = msgpack.loads(response.content)
    return decode_data(results)
if __name__ == "__main__":
    server_url = "http://127.0.0.1:9520"
    model_name = "efficientsam-encoder"
    # Get the model config with information about inputs/outputs etc.
    config = get_config(server_url, model_name)
    # print(config)
    # Run inference with the efficientsam-encoder model.
    # convert("RGB") guarantees exactly 3 channels, so the HWC -> CHW transpose
    # below also works for grayscale, palette or RGBA files. The `with` block
    # closes the underlying file handle.
    with Image.open("tmp/dogs.jpg") as img:
        image = np.array(img.convert("RGB"))
    # HWC uint8 -> 1xCxHxW float32 scaled to [0, 1]
    input_image = image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
    results = execute(
        inputs=[input_image],
        server_url=server_url,
        model_name=model_name,
        decode_json=True,
    )
    embeddings = results["image_embeddings"]
    print(embeddings.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment