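"""Evaluate quantized ONNX image-classification models on the FuriosaAI NPU
runtime: submit inference requests asynchronously over a 1,000-image ImageNet
validation subset, then report top-1 accuracy and average per-image latency."""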
import time

import numpy as np
import torch
import torchvision
from torchvision import transforms, models
import tqdm

import furiosa.runtime.session
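
# Standard ImageNet preprocessing: resize to 256, center-crop to 224x224,
# and normalize with the conventional ImageNet channel means and std devs.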
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

models_dict = {
    'regnet_y_400mf': models.regnet_y_400mf(pretrained=True),
    # add more models here
}
total_images = 1000
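
# torchvision's ImageNet dataset does not download automatically; the
# validation archives must already be present under the "imagenet" root.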
imagenet = torchvision.datasets.ImageNet("imagenet", split="val", transform=preprocess)
run_outputs = []

for model_name in models_dict:
    print(model_name, models_dict[model_name].__class__.__name__)

    # Evaluate on a random subset of total_images validation images.
    validation_dataset = torch.utils.data.Subset(
        imagenet, torch.randperm(len(imagenet))[:total_images])
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=1)

    quantized_model_name = f'{model_name}_quantized.onnx'

    start_time = time.perf_counter_ns()
    # create_async() returns a (submitter, queue) pair: requests go in
    # through the submitter; completed results come back on the queue.
    submitter, queue = furiosa.runtime.session.create_async(
        quantized_model_name,
        worker_num=1,
        # The queue sizes determine how many asynchronous requests
        # you can submit without blocking.
        input_queue_size=total_images,
        output_queue_size=total_images)
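
    # The run has two phases: first, every request is submitted without
    # waiting; then all completions are drained from the output queue.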
    correct_predictions, total_predictions = 0, 0

    # Submit the inference requests asynchronously, passing each sample's
    # index as the context so the result can be matched back to its label.
    labels = []
    for i, (image, label) in enumerate(tqdm.tqdm(
            validation_dataloader, desc="Evaluation", unit="images", mininterval=0.5)):
        labels.append(label)
        submitter.submit(image.numpy(), context=i)

    # Receive the results asynchronously. Completions may arrive out of
    # order, so look up the matching label through the returned context.
    for _ in range(total_images):
        context, outputs = queue.recv(100)  # timeout; with None, queue.recv() blocks
        prediction = np.argmax(outputs[0].numpy(), axis=1)  # postprocess: top-1 class index
        if int(prediction[0]) == int(labels[context]):
            correct_predictions += 1
        total_predictions += 1

    # Wall-clock time for the whole async run (session creation, submission,
    # and completion), averaged over the number of processed samples.
    elapsed_time = time.perf_counter_ns() - start_time
    print(f'Total samples: {total_predictions}')
    print(f'Correct predictions: {correct_predictions} out of {total_predictions}')
    latency = elapsed_time / total_predictions
    avg_latency = latency / 1_000_000  # nanoseconds -> milliseconds
    print(f"Average Latency: {avg_latency:0.3f} ms")