Skip to content

Instantly share code, notes, and snippets.

@kimdwkimdw
Last active October 5, 2022 12:21
Show Gist options
  • Save kimdwkimdw/5e290a6ac9c4816e2f03343a3654735e to your computer and use it in GitHub Desktop.
Save kimdwkimdw/5e290a6ac9c4816e2f03343a3654735e to your computer and use it in GitHub Desktop.
import argparse
import os
from glob import glob
import numpy as np
from PIL import Image
from tritony import InferenceClient
def preprocess(img, dtype=np.float32, h=224, w=224, scaling="INCEPTION"):
    """Convert a PIL image into a channel-first array for an image classifier.

    The image is converted to RGB, resized to ``(h, w)`` with bilinear
    resampling, normalised according to ``scaling``, and transposed from
    HWC to CHW layout.

    Args:
        img: A PIL ``Image`` in any mode (it is converted to RGB first).
        dtype: NumPy dtype of the returned array.
        h: Output height in pixels.
        w: Output width in pixels.
        scaling: Pixel normalisation to apply (case-insensitive string or None):
            ``"INCEPTION"`` (default) maps [0, 255] to [-1, 1];
            ``"VGG"`` subtracts the per-channel RGB means (123, 117, 104);
            ``None`` or ``"NONE"`` keeps the raw pixel values.

    Returns:
        ``np.ndarray`` of shape ``(3, h, w)`` with dtype ``dtype``.

    Raises:
        ValueError: If ``scaling`` is not one of the supported modes.
    """
    # convert("RGB") guarantees three channels, so `pixels` is always (h, w, 3);
    # no grayscale (ndim == 2) special case is needed.
    rgb = img.convert("RGB").resize((w, h), Image.Resampling.BILINEAR)
    pixels = np.array(rgb)

    mode = scaling.upper() if isinstance(scaling, str) else scaling
    if mode == "INCEPTION":
        # Same arithmetic as before (computed in float64, cast at the end),
        # so the default path produces identical values.
        scaled = (pixels / 127.5) - 1
    elif mode == "VGG":
        # Subtract in float64 so uint8 pixels cannot wrap around below zero.
        scaled = pixels - np.asarray((123, 117, 104), dtype=np.float64)
    elif mode is None or mode == "NONE":
        scaled = pixels
    else:
        raise ValueError(
            f"unsupported scaling {scaling!r}; expected 'INCEPTION', 'VGG' or None"
        )

    # HWC -> CHW (channel-first).
    return np.transpose(scaled, (2, 0, 1)).astype(dtype)
def main():
    """Classify every image in a folder with a Triton-served model and print the top-1 class.

    Command-line arguments:
        --image_folder  Folder whose regular files are read as images (required).
        --model         Model name on the Triton server (default: ``densenet_onnx``).
        --url           Triton gRPC endpoint as ``host:port`` (default: ``0.0.0.0:8001``).

    Files that cannot be opened as images are skipped with a warning on stderr.
    The script exits with a usage error if the folder is missing or contains
    no readable images.
    """
    import sys  # local import: only used to write warnings to stderr

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, required=True, help="Input folder.")
    parser.add_argument(
        "--model", type=str, default="densenet_onnx", help="Model name on the Triton server."
    )
    parser.add_argument(
        "--url", type=str, default="0.0.0.0:8001", help="Triton gRPC endpoint (host:port)."
    )
    FLAGS = parser.parse_args()

    if not os.path.isdir(FLAGS.image_folder):
        parser.error(f"--image_folder {FLAGS.image_folder!r} is not a directory")

    # glob("*") also matches subdirectories, so keep only regular files.
    # Sorting makes the batch order (and so the order of printed results) deterministic.
    filenames = sorted(
        path
        for path in glob(os.path.join(FLAGS.image_folder, "*"))
        if os.path.isfile(path)
    )

    image_data = []
    for filename in filenames:
        try:
            # `with` closes the file handle Image.open keeps; preprocess() calls
            # convert(), which reads the pixel data before the file is closed.
            with Image.open(filename) as img:
                image_data.append(preprocess(img))
        except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
            print(f"skipping {filename}: {err}", file=sys.stderr)

    if not image_data:
        parser.error(f"no readable images found in {FLAGS.image_folder!r}")

    client = InferenceClient.create_with(FLAGS.model, FLAGS.url, input_dims=3, protocol="grpc")
    client.output_kwargs = {"class_count": 1}
    result = client(np.asarray(image_data))
    for output in result:
        # Each entry decodes to "<value>:<index>:<label>"; maxsplit=2 keeps any
        # ':' that appears inside the label text.
        max_value, arg_max, class_name = output[0].decode("utf-8").split(":", 2)
        print(f"{max_value} ({arg_max}) = {class_name}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment