Last active
June 14, 2024 13:29
-
-
Save github-louis-fruleux/fca27dbbc1a30a030fae31211bc4b684 to your computer and use it in GitHub Desktop.
String upper-casing is slower in ONNX than in plain Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
import numpy as np | |
def setup_code(
    model_name: str,
    *,
    num_samples: int = 1,
    num_features: int = 10,
    string_len: int = 8,
    seed: "int | None" = None,
) -> str:
    """Build the ``timeit`` setup source that loads an ONNX model and its input.

    The returned code creates an onnxruntime ``InferenceSession`` for
    ``<model_name>.onnx``. The onnxruntime-extensions custom-op library is
    registered and all graph optimizations are enabled. The code also binds
    ``input_data``, a feed dict whose ``'string_input'`` entry is a
    ``num_samples`` x ``num_features`` nested list of random lowercase ASCII
    strings, each ``string_len`` characters long.

    Only the modules the setup actually uses are imported. Heavy, unused
    libraries (TensorFlow, tf2onnx, ...) are not needed to time an ONNX
    session.

    Args:
        model_name: Model file name without the ``.onnx`` suffix.
        num_samples: Number of rows in the generated input.
        num_features: Number of strings per row.
        string_len: Length of every generated string.
        seed: Seed for ``random``. ``None`` (the default) seeds from system
            entropy, i.e. unseeded behavior. Pass an int to feed every
            benchmark the same input.

    Returns:
        Python source suitable for ``timeit``'s ``setup=`` argument.
    """
    # ``!r`` emits a correctly quoted/escaped literal, so unusual model names
    # cannot break the generated code. ``:d`` rejects non-integer sizes early.
    return f"""
import random
import string

import onnxruntime as _ort
from onnxruntime_extensions import get_library_path as _lib_path


def generate_random_string(string_len):
    return ''.join(random.choice(string.ascii_lowercase) for _ in range(string_len))


def generate_input(num_samples, num_features, string_len):
    return [[generate_random_string(string_len) for _ in range(num_features)] for _ in range(num_samples)]


random.seed({seed!r})
so = _ort.SessionOptions()
so.register_custom_ops_library(_lib_path())
so.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = _ort.InferenceSession({model_name + '.onnx'!r}, so)
input_data = {{
    'string_input': generate_input({num_samples:d}, {num_features:d}, {string_len:d})
}}
"""
# timeit.repeat parameters: number of repetitions, and executions of the
# timed statement per repetition.
REPEAT = 10
NUMBER = 1000


def _report(label: str, times) -> None:
    """Print the mean per-call latency in microseconds with a +/- 3 sigma spread.

    Each entry of *times* is the total wall time in seconds for NUMBER
    executions, so the per-call latency in microseconds is times * 1e6 / NUMBER.
    """
    per_call_us = np.asarray(times) * 1e6 / NUMBER
    print(f"{label}: Execution time: {per_call_us.mean()} +/- {3 * per_call_us.std()} us")


# Timed statement for the baseline: upper-case the strings in plain Python,
# then run the model that has no Upper node. The upper-cased copy goes into
# a fresh dict, so input_data stays lowercase and every iteration does the
# same work.
test_code = """
upper_input = {"string_input": [[s.upper() for s in row] for row in input_data["string_input"]]}
session.run(None, upper_input)
"""
times = timeit.repeat(setup=setup_code("onnx_without_upper"), stmt=test_code, repeat=REPEAT, number=NUMBER)
_report("python upper + onnx_without_upper", times)

# Timed statement for the candidate: the model upper-cases the strings itself.
test_code = """
session.run(None, input_data)
"""
times = timeit.repeat(setup=setup_code("onnx_with_upper"), stmt=test_code, repeat=REPEAT, number=NUMBER)
_report("onnx_with_upper", times)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
tensorflow==2.15.0
onnx
onnxruntime
onnxruntime-extensions
tf2onnx
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment