Skip to content

Instantly share code, notes, and snippets.

@fxmarty
Created October 26, 2023 15:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fxmarty/a8edcdb9dffee74790e6bdb42a1bbeaf to your computer and use it in GitHub Desktop.
Save fxmarty/a8edcdb9dffee74790e6bdb42a1bbeaf to your computer and use it in GitHub Desktop.
Benchmark Donut generation latency: PyTorch vs. ONNX Runtime (Optimum)
from transformers import AutoTokenizer
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
import time
from optimum.onnxruntime import ORTModelForVision2Seq
from PIL import Image
import onnxruntime as onnxrt
# Target device for both backends. The ONNX Runtime execution provider is derived
# from it; any device other than "cpu" is treated as CUDA.
device = "cpu"
provider = "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
# Reference PyTorch model, fetched by Hub id and moved to the benchmark device.
hf_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2").to(device)

# Exported ONNX model and its processor, loaded from a local export directory.
onnx_dir = "/home/fxmarty/hf_internship/optimum/donut_onnx"
# NOTE(review): IO binding is requested regardless of provider (including
# CPUExecutionProvider) — confirm this is intended for CPU runs.
onnx_model = ORTModelForVision2Seq.from_pretrained(onnx_dir, use_io_binding=True, provider=provider)
# The processor bundles the tokenizer (used below via `processor.tokenizer`), so a
# separate AutoTokenizer load would only duplicate work and was never used.
processor = DonutProcessor.from_pretrained(onnx_dir)

image = Image.open('/home/fxmarty/Downloads/example_donut.jpg')

# The decoder is primed with the task prompt token; special tokens are not added
# because the prompt itself is the start sequence.
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids.to(device)

# Identical preprocessed image tensor is fed to both backends.
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
# Generation settings shared by both backends so they decode under identical
# conditions: greedy search (num_beams=1), KV cache on, at most 40 new tokens.
generate_kwargs = dict(
    decoder_input_ids=decoder_input_ids,
    max_new_tokens=40,
    use_cache=True,
    num_beams=1,
)
n_runs = 10


def _sync_device():
    """Block until queued GPU work finishes so wall-clock timings include it."""
    # Mirrors the provider selection: anything other than "cpu" is treated as CUDA.
    if device != "cpu":
        torch.cuda.synchronize()


def _benchmark_generate(model, runs):
    """Time ``model.generate`` on the shared inputs.

    One untimed warmup call is made first, then ``runs`` timed calls.

    Args:
        model: Object exposing ``generate(pixel_values, **generate_kwargs)``.
        runs: Number of timed calls; must be at least 1.

    Returns:
        Tuple of (output of the last timed call, mean latency in milliseconds).

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")
    # warmup: excluded from the timed loop
    model.generate(pixel_values, **generate_kwargs)
    _sync_device()
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    for _ in range(runs):
        out = model.generate(pixel_values, **generate_kwargs)
    _sync_device()
    elapsed_ms = (time.perf_counter() - start) / runs * 1e3
    return out, elapsed_ms


outputs, took = _benchmark_generate(hf_model, n_runs)
torch_output = processor.batch_decode(outputs)
print(f"Time torch: {took:.3f} ms {device}")

generated_ids, took = _benchmark_generate(onnx_model, n_runs)
# Decode the ONNX Runtime result itself; decoding the PyTorch `outputs` here would
# make the equality check below trivially true.
ort_output = processor.batch_decode(generated_ids)
print(f"Time ONNX: {took:.3f} ms {device}")
print("PyTorch & ORT match:", ort_output == torch_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment