## Extended from https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb
from transformers import DistilBertTokenizerFast, DistilBertModel
from torch.cuda import get_device_name
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List
from time import time
from tqdm import trange
from os import environ
from psutil import cpu_count
# Constants for the performance optimizations available in onnxruntime.
# These have to be set before onnxruntime is imported.
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = 'ACTIVE'
from onnxruntime import InferenceSession, SessionOptions, get_all_providers
import onnxruntime

def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:
    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"

    # A few session properties that might have an impact on performance (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = 4
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

    # Load the model as a graph and prepare the backend for the requested provider
    return InferenceSession(model_path, options, providers=[provider])
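
# Sketch (not part of the original gist): get_all_providers() lists every provider the
# wheel was built with, while get_available_providers() lists the ones actually usable
# on this machine, so printing it is a quick sanity check before requesting
# CUDAExecutionProvider.
print("Available providers:", onnxruntime.get_available_providers())
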
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased-distilled-squad")
text = "what is the value of pi ? [SEP] some people have said that pi is nice but there should be a value for pi, the value for pi is around 2.4 which is wrong"
input_ids = tokenizer.encode_plus(text,return_tensors="pt")
print(input_ids)
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in input_ids.items()}
#for k in inputs_onnx:
# inputs_onnx[k] = inputs_onnx[k].reshape((39))
print(f"Doing GPU inference on {get_device_name(0)}", flush=True)

@contextmanager
def track_infer_time(buffer: List[int]):
    start = time()
    yield
    end = time()
    buffer.append(end - start)

@dataclass
class OnnxInferenceResult:
    model_inference_time: List[int]
    optimized_model_path: str

# Collected results, keyed by provider
results = {}

# All the providers we'll be using in the test
providers = [
    # "CPUExecutionProvider",
    # "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
]

# Iterate over all the providers
for provider in providers:
    # Create the model with the specified provider
    model = create_model_for_provider("dbert_squad_trt.onnx", provider)
    model.run(None, inputs_onnx)
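
    # What follows is a sketch, not part of the original gist: it wires the otherwise
    # unused track_infer_time, trange and OnnxInferenceResult helpers into a simple
    # benchmark. The iteration count (100) and reusing the same model path are assumptions.
    time_buffer = []
    for _ in trange(100, desc=f"Benchmarking {provider}"):
        with track_infer_time(time_buffer):
            model.run(None, inputs_onnx)

    results[provider] = OnnxInferenceResult(
        model_inference_time=time_buffer,
        optimized_model_path="dbert_squad_trt.onnx",
    )

# Report the mean latency observed for each provider
for provider, result in results.items():
    mean_ms = 1000 * sum(result.model_inference_time) / len(result.model_inference_time)
    print(f"{provider}: {mean_ms:.2f} ms / inference", flush=True)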