Skip to content

Instantly share code, notes, and snippets.

@bddppq
Created September 26, 2017 03:17
Show Gist options
  • Save bddppq/6e9e0667a46bdd0f5d0d9d6392528856 to your computer and use it in GitHub Desktop.
Save bddppq/6e9e0667a46bdd0f5d0d9d6392528856 to your computer and use it in GitHub Desktop.
from argparse import ArgumentParser
import os
from timeit import Timer
import numpy as np
from caffe2.python import workspace
import onnx
from onnx.numpy_helper import to_array
import onnx_caffe2.backend
import torch
import torch.onnx
from torch.autograd import Variable
from torch import nn
import torchvision
def create_graph(model_name, batch_size):
    """Export a torchvision model to ONNX and load the result back.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``
            (looked up with ``getattr``).
        batch_size: Leading dimension of the dummy input used for the export;
            the input has shape ``(batch_size, 3, 224, 224)``.

    Returns:
        Whatever ``onnx.load`` returns for the exported file.

    Raises:
        AttributeError: If ``torchvision.models`` has no attribute
            ``model_name``.
    """
    # Local import: only this function needs tempfile.
    import tempfile

    # Randomly initialised weights (pretrained=False), switched to eval mode
    # before export.
    model = getattr(torchvision.models, model_name)(pretrained=False)
    model = model.eval()
    dummy_input = Variable(torch.rand(batch_size, 3, 224, 224))

    # Export to a unique temporary file rather than a fixed "tmp.onnx" in the
    # current directory, so concurrent runs cannot clobber each other and an
    # existing file of that name is never overwritten or deleted.
    fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
    os.close(fd)  # torch/onnx reopen the file by path
    try:
        torch.onnx.export(model, dummy_input, onnx_path, export_params=True)
        return onnx.load(onnx_path)
    finally:
        # Always clean up, even if export or load raised.
        os.remove(onnx_path)
def benchmark_caffe2(model_name, batch_size, runs):
    """Benchmark a torchvision model under the Caffe2 ONNX backend.

    The model is exported to ONNX via ``create_graph``, prepared with
    ``onnx_caffe2.backend.prepare``, fed a random input and timed with
    ``workspace.BenchmarkNet``.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.
        batch_size: Leading dimension of the ``(batch_size, 3, 224, 224)``
            float32 input.
        runs: Value passed as the third argument of ``BenchmarkNet``
            (the number of timed iterations -- confirm against your Caffe2
            version).

    Returns:
        The value returned by ``workspace.BenchmarkNet``. NOTE(review): in
        Caffe2 this is expected to be a list of millisecond timings with the
        whole-net time first -- confirm before relying on the layout.
    """
    graph = create_graph(model_name, batch_size)

    # Feed initialised data: np.ndarray(shape) returns uninitialised memory,
    # which may contain NaNs/denormals and makes timings non-reproducible.
    dummy_input = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)

    prepared_backend = onnx_caffe2.backend.prepare(graph)
    # Run inside the workspace the backend populated with the model weights.
    workspace.SwitchWorkspace(prepared_backend.workspace.workspace_id)
    # Presumably the last graph input is the data blob (weights come first);
    # verify against the exporter's input ordering.
    workspace.FeedBlob(graph.input[-1], dummy_input)
    workspace.CreateNet(prepared_backend.predict_net, False)

    # Positional arguments: net name, 1, runs, True -- presumably
    # warm-up runs, timed runs and per-operator timing; confirm.
    return workspace.BenchmarkNet(
        prepared_backend.predict_net.name,
        1,
        runs,
        True,
    )
def run_all_benchmarks(model_name, batch_sizes, runs):
    """Run the Caffe2 benchmark once for every requested batch size.

    Prints a ``Batch Size=<n>`` header before each benchmark.

    Args:
        model_name: Model name forwarded to ``benchmark_caffe2``.
        batch_sizes: Iterable of batch sizes to evaluate, in order.
        runs: Run count forwarded to ``benchmark_caffe2``.

    Returns:
        None.
    """
    for batch_size in batch_sizes:
        print("Batch Size={}".format(batch_size))
        # The result was previously bound to a local that was never read.
        benchmark_caffe2(model_name, batch_size, runs)
if __name__ == "__main__":
    # Command-line entry point: parse the options and benchmark the chosen
    # model at every requested batch size.
    parser = ArgumentParser()
    parser.add_argument("--model_name", required=True, type=str, help="torchvision model to run")
    # required=True: without it an omitted flag yields None, which crashed
    # later with a TypeError when iterated; argparse now reports a usage error.
    parser.add_argument("--batch_sizes", nargs="+", required=True, type=int, help="batch sizes to evaluate")
    parser.add_argument("--runs", default=10, type=int, help="number of runs per batch size")
    opts = parser.parse_args()
    run_all_benchmarks(opts.model_name, opts.batch_sizes, opts.runs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment