Created
September 26, 2017 03:17
-
-
Save bddppq/6e9e0667a46bdd0f5d0d9d6392528856 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from argparse import ArgumentParser
import os
import tempfile
from timeit import Timer

import numpy as np
from caffe2.python import workspace
import onnx
from onnx.numpy_helper import to_array
import onnx_caffe2.backend
import torch
import torch.onnx
from torch.autograd import Variable
from torch import nn
import torchvision
def create_graph(model_name, batch_size):
    """Export an (untrained) torchvision model to ONNX and load it back.

    Args:
        model_name: name of a model constructor in ``torchvision.models``
            (e.g. ``"resnet50"``); instantiated with ``pretrained=False``.
        batch_size: batch dimension of the dummy (batch_size, 3, 224, 224)
            input used to trace the model during export.

    Returns:
        The ONNX model object returned by ``onnx.load``.
    """
    model = getattr(torchvision.models, model_name)(pretrained=False)
    model = model.eval()
    dummy_input = Variable(torch.rand(batch_size, 3, 224, 224))
    # Export to a private temporary file instead of a fixed "tmp.onnx" in the
    # CWD: concurrent runs no longer clobber each other, and the file is
    # removed even when export or load raises.
    fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
    os.close(fd)  # only the path is needed; exporter reopens it for writing
    try:
        torch.onnx.export(model, dummy_input, onnx_path, export_params=True)
        graph = onnx.load(onnx_path)
    finally:
        os.remove(onnx_path)
    return graph
def benchmark_caffe2(model_name, batch_size, runs):
    """Benchmark a torchvision model exported to ONNX and executed by Caffe2.

    Args:
        model_name: name of a model constructor in ``torchvision.models``.
        batch_size: batch dimension of the (batch_size, 3, 224, 224) input.
        runs: number of timed iterations handed to ``workspace.BenchmarkNet``.

    Returns:
        The value returned by ``workspace.BenchmarkNet``.
        NOTE(review): in Caffe2 this is presumably a list of floats whose
        first entry is milliseconds per iteration (followed by per-operator
        timings) -- confirm against the installed Caffe2 version.
    """
    graph = create_graph(model_name, batch_size)
    # np.ndarray(shape, dtype) returns *uninitialized* memory; garbage values
    # (NaN/Inf/denormals) can distort timings, so feed real random data.
    dummy_input = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)
    prepared_backend = onnx_caffe2.backend.prepare(graph)
    # Make the backend's own workspace current so the FeedBlob / CreateNet /
    # BenchmarkNet calls below act on the blobs and net prepare() set up.
    workspace.SwitchWorkspace(prepared_backend.workspace.workspace_id)
    # NOTE(review): assumes the last graph input is the image tensor (the
    # remaining inputs being exported parameters) -- confirm for the
    # torch.onnx version in use.
    workspace.FeedBlob(graph.input[-1], dummy_input)
    workspace.CreateNet(prepared_backend.predict_net, False)
    # Positional args follow BenchmarkNet(name, warmup_runs, main_runs,
    # run_individual); return the stats instead of discarding them.
    return workspace.BenchmarkNet(prepared_backend.predict_net.name,
                                  1,
                                  runs,
                                  True)
def run_all_benchmarks(model_name, batch_sizes, runs):
    """Run ``benchmark_caffe2`` once per batch size and collect the results.

    Args:
        model_name: name of a model constructor in ``torchvision.models``.
        batch_sizes: iterable of batch sizes to benchmark, in order.
        runs: number of timed iterations per batch size.

    Returns:
        A dict mapping each batch size to whatever ``benchmark_caffe2``
        returned for it (a repeated batch size keeps its last result).
    """
    results = {}
    for batch_size in batch_sizes:
        print("Batch Size={}".format(batch_size))
        results[batch_size] = benchmark_caffe2(model_name, batch_size, runs)
    return results
if __name__ == "__main__":
    # Command-line entry point: parse options and benchmark every batch size.
    parser = ArgumentParser()
    parser.add_argument("--model_name", required=True, type=str, help="torchvision model to run")
    # Without a default, omitting --batch_sizes left opts.batch_sizes as None
    # and run_all_benchmarks crashed trying to iterate it.
    parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1], help="batch sizes to evaluate")
    parser.add_argument("--runs", default=10, type=int, help="number of runs per batch size")
    opts = parser.parse_args()
    run_all_benchmarks(opts.model_name, opts.batch_sizes, opts.runs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment