Last active
June 25, 2024 17:28
-
-
Save mcollinswisc/d1cd9d13b4e5fbad01c75dca5c9ca576 to your computer and use it in GitHub Desktop.
qdq_flatten_optim.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import numpy as np | |
import onnx | |
import onnxruntime | |
# File the unoptimized test model is saved to before the session is created.
_ORIG_MODEL_PATH = "orig_model.onnx"
# File ONNX Runtime writes its optimized graph to, for inspecting what the
# graph optimizer did to the model.
_OPTIM_MODEL_PATH = "optim_model.onnx"
# Enable every graph optimization ONNX Runtime offers.
_ORT_OPTIM_LEVEL = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Whether the session records a profiling trace.
_ORT_ENABLE_PROFILING = True
class MiddleOp(enum.Enum):
    """Operator placed between the DequantizeLinear and QuantizeLinear nodes."""

    FLATTEN = enum.auto()  # value 1
    RESHAPE = enum.auto()  # value 2


# Middle operator the test model is built with.
_TEST_OP = MiddleOp.FLATTEN
def _make_test_model(test_op=None):
    """Build the DequantizeLinear -> middle op -> QuantizeLinear test model.

    The graph:

    * dequantizes the ``uint8`` input ``"input"`` (shape ``[16, 16]``) using
      the scalar initializers ``"scale"`` and ``"zero_point"`` (zero point 0);
    * flattens the result to ``[1, 256]`` with the selected middle op;
    * quantizes it back to the ``uint8`` output ``"output"`` with the same
      scale and zero point.

    Args:
        test_op: The ``MiddleOp`` to place between the DQ and Q nodes. When
            ``None`` (the default), the module-level ``_TEST_OP`` is used,
            looked up at call time.

    Returns:
        An ``onnx.ModelProto`` that has passed ``onnx.checker.check_model``.

    Raises:
        ValueError: If ``test_op`` is not a supported ``MiddleOp``.
    """
    if test_op is None:
        # Resolve at call time (not as a default argument) so a reassigned
        # _TEST_OP still takes effect, as it did when the global was read here.
        test_op = _TEST_OP

    input_value = onnx.helper.make_tensor_value_info(
        name="input",
        elem_type=onnx.TensorProto.UINT8,
        shape=[16, 16],
    )
    # DQ and Q share these two scalar initializers, so dequantize followed by
    # quantize should give the original uint8 values back.
    scale = onnx.numpy_helper.from_array(
        np.array(0.045887697488069534, dtype=np.float32),
        name="scale",
    )
    zero_point = onnx.numpy_helper.from_array(
        np.zeros([], dtype=np.uint8),
        name="zero_point",
    )
    initializer = [scale, zero_point]

    dq = onnx.helper.make_node(
        op_type="DequantizeLinear",
        inputs=["input", "scale", "zero_point"],
        outputs=["nonflat_float"],
    )

    if test_op is MiddleOp.FLATTEN:
        # axis=0 makes the outer output dim 1 and folds every input dim into
        # the inner one: [16, 16] -> [1, 256].
        middle_op = onnx.helper.make_node(
            op_type="Flatten",
            inputs=["nonflat_float"],
            outputs=["flat_float"],
            axis=0,
        )
    elif test_op is MiddleOp.RESHAPE:
        # Reshape takes its target shape as a tensor input, so it needs an
        # extra int64 initializer.
        shape = onnx.numpy_helper.from_array(
            np.array([1, 256], dtype=np.int64),
            name="shape",
        )
        initializer.append(shape)
        middle_op = onnx.helper.make_node(
            op_type="Reshape",
            inputs=["nonflat_float", "shape"],
            outputs=["flat_float"],
        )
    else:
        raise ValueError(f"Unknown middle op: {test_op!r}")

    q = onnx.helper.make_node(
        op_type="QuantizeLinear",
        inputs=["flat_float", "scale", "zero_point"],
        outputs=["output"],
    )
    output_value = onnx.helper.make_tensor_value_info(
        name="output",
        elem_type=onnx.TensorProto.UINT8,
        shape=[1, 256],
    )

    graph = onnx.helper.make_graph(
        nodes=[dq, middle_op, q],
        name="test_graph",
        inputs=[input_value],
        outputs=[output_value],
        initializer=initializer,
    )
    model = onnx.helper.make_model(graph)
    onnx.checker.check_model(model)
    return model
def _make_sess(model):
    """Create an ONNX Runtime inference session for ``model``.

    The session uses ``_ORT_OPTIM_LEVEL`` as its graph optimization level. It
    writes the optimized graph to ``_OPTIM_MODEL_PATH`` and enables profiling
    according to ``_ORT_ENABLE_PROFILING``.

    Args:
        model: The ``onnx.ModelProto`` to run. It is serialized and passed to
            ONNX Runtime in memory, so it does not have to be saved to disk
            first.

    Returns:
        An ``onnxruntime.InferenceSession`` on the CPU execution provider.
    """
    sess_opts = onnxruntime.SessionOptions()
    sess_opts.graph_optimization_level = _ORT_OPTIM_LEVEL
    sess_opts.optimized_model_filepath = _OPTIM_MODEL_PATH
    sess_opts.enable_profiling = _ORT_ENABLE_PROFILING
    # Load the model that was actually passed in rather than re-reading
    # _ORIG_MODEL_PATH, which may be missing or hold a different model.
    # Name the provider explicitly: ORT (>= 1.9) builds that ship more than
    # one execution provider raise when ``providers`` is omitted.
    return onnxruntime.InferenceSession(
        model.SerializeToString(),
        sess_opts,
        providers=["CPUExecutionProvider"],
    )
# Build the test model, keep a copy on disk, and open an ORT session on it.
model = _make_test_model()
onnx.save(model, _ORIG_MODEL_PATH)
sess = _make_sess(model)

# Feed the values 0..255 laid out as a 16x16 uint8 grid.
in_array = np.arange(256, dtype=np.uint8).reshape(16, 16)
in_dict = {"input": in_array}
(out,) = sess.run(["output"], in_dict)

assert out.dtype == np.uint8
assert out.shape == (1, 256)
# DQ and Q share scale and zero point, so the output should be the input
# values unchanged, flattened into a single row.
expected_out = np.arange(256, dtype=np.uint8).reshape(1, 256)
np.testing.assert_array_equal(out, expected_out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment