## Extended from https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb
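# The script below loads an ONNX graph named "dbert_squad_trt.onnx" but does not show how it
# was produced. The block that follows is only a minimal, hypothetical export sketch using
# torch.onnx.export; it is not necessarily how the original file was created (the "_trt"
# suffix suggests it may have been further processed for TensorRT). The checkpoint name,
# dummy text, opset version, and the input/output names are assumptions, with input names
# chosen to match the feed dict built further down.
from pathlib import Path

if not Path("dbert_squad_trt.onnx").exists():
    import torch
    from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast

    # torchscript=True makes the model return plain tuples, which keeps tracing simple
    qa_model = DistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-uncased-distilled-squad", torchscript=True
    ).eval()
    qa_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased-distilled-squad")
    dummy = qa_tokenizer("what is pi ?", "pi is a mathematical constant", return_tensors="pt")

    torch.onnx.export(
        qa_model,
        (dummy["input_ids"], dummy["attention_mask"]),
        "dbert_squad_trt.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["start_logits", "end_logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
        },
        opset_version=11,
    )
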
from transformers import DistilBertTokenizerFast, DistilBertModel
from torch.cuda import get_device_name
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List
from time import time
from tqdm import trange
from os import environ
from psutil import cpu_count

# Constants from the performance optimization available in onnxruntime.
# They need to be set before importing onnxruntime.
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = "ACTIVE"

from onnxruntime import InferenceSession, SessionOptions, get_all_providers
import onnxruntime
def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:
    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"

    # A few session properties that can have an impact on performance (provided by MS)
    options = SessionOptions()
    # print(options)
    options.intra_op_num_threads = 4
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

    # Load the model as a graph and prepare the chosen execution backend
    return InferenceSession(model_path, options, providers=[provider])
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased-distilled-squad")

text = "what is the value of pi ? [SEP] some people have said that pi is nice but there should be a value for pi, the value for pi is around 2.4 which is wrong"
input_ids = tokenizer.encode_plus(text, return_tensors="pt")
print(input_ids)

# ONNX Runtime expects plain numpy arrays, so move the torch tensors to CPU and convert
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in input_ids.items()}
# for k in inputs_onnx:
#     inputs_onnx[k] = inputs_onnx[k].reshape((39))

print(f"Doing GPU inference on {get_device_name(0)}", flush=True)
@contextmanager
def track_infer_time(buffer: List[float]):
    # Time the body of the `with` block and append the elapsed seconds to `buffer`
    start = time()
    yield
    end = time()
    buffer.append(end - start)


@dataclass
class OnnxInferenceResult:
    model_inference_time: List[float]
    optimized_model_path: str
results = {}

# All the providers we'll be using in the test
# providers = [
#     "CUDAExecutionProvider",
#     "CPUExecutionProvider",
#     "TensorrtExecutionProvider",
# ]
providers = [
    # "CPUExecutionProvider",
    "CUDAExecutionProvider",
]
# Iterate over all the providers
for provider in providers:
    # Create the model with the specified provider
    model = create_model_for_provider("dbert_squad_trt.onnx", provider)
    model.run(None, inputs_onnx)
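
# The helpers above (track_infer_time, OnnxInferenceResult, results) are defined but never
# exercised in the original snippet. Below is a minimal sketch, following the pattern of the
# notebook this gist extends, of how they could be wired up to benchmark the session created
# in the loop above. The iteration count (100) and treating the single model.run() call in
# the loop as a warm-up are assumptions, not part of the original gist.
print("Active execution providers:", model.get_providers())

time_buffer = []
for _ in trange(100, desc=f"Tracking inference time on {provider}"):
    with track_infer_time(time_buffer):
        model.run(None, inputs_onnx)

# Keep the timings so different providers could be compared afterwards
results[provider] = OnnxInferenceResult(
    model_inference_time=time_buffer,
    optimized_model_path="dbert_squad_trt.onnx",
)
print(f"{provider}: mean inference time = {sum(time_buffer) / len(time_buffer) * 1000:.2f} ms "
      f"over {len(time_buffer)} runs")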