Skip to content

Instantly share code, notes, and snippets.

@mcollinswisc
Last active June 25, 2024 17:28
Show Gist options
  • Save mcollinswisc/d1cd9d13b4e5fbad01c75dca5c9ca576 to your computer and use it in GitHub Desktop.
Save mcollinswisc/d1cd9d13b4e5fbad01c75dca5c9ca576 to your computer and use it in GitHub Desktop.
qdq_flatten_optim.py
import enum
import numpy as np
import onnx
import onnxruntime
# File path where the generated (unoptimized) test model is saved.
_ORIG_MODEL_PATH = "orig_model.onnx"
# Graph optimization level assigned to the ONNX Runtime session options.
_ORT_OPTIM_LEVEL = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Value assigned to SessionOptions.optimized_model_filepath.
_OPTIM_MODEL_PATH = "optim_model.onnx"
# Whether profiling is enabled on the ONNX Runtime session.
_ORT_ENABLE_PROFILING = True
@enum.unique
class MiddleOp(enum.Enum):
    """Operator placed between DequantizeLinear and QuantizeLinear.

    Both members describe a way of turning the dequantized tensor into a
    single-row 2-D tensor; they differ only in which ONNX operator is used.
    """

    # Use an ONNX ``Flatten`` node.
    FLATTEN = 1
    # Use an ONNX ``Reshape`` node with an explicit target shape.
    RESHAPE = 2
# Selects which middle operator (Flatten or Reshape) the test model is built with.
_TEST_OP = MiddleOp.FLATTEN
def _make_test_model(test_op=None):
    """Build a DequantizeLinear -> (Flatten | Reshape) -> QuantizeLinear model.

    The graph takes a uint8 tensor named ``"input"`` of shape ``[16, 16]``,
    dequantizes it, turns it into a ``[1, 256]`` tensor with the selected
    middle operator, and quantizes it again into a uint8 tensor named
    ``"output"``.  The same scalar ``scale`` and ``zero_point`` initializers
    are shared by the dequantize and quantize nodes.

    Args:
        test_op: The ``MiddleOp`` member selecting the middle operator.  When
            ``None`` (the default), the module-level ``_TEST_OP`` is used, so
            existing zero-argument callers behave exactly as before.

    Returns:
        The model produced by ``onnx.helper.make_model``, after it has been
        passed through ``onnx.checker.check_model``.

    Raises:
        ValueError: If ``test_op`` is not a known ``MiddleOp`` member.
    """
    if test_op is None:
        # Looked up at call time so a changed module-level default is honored.
        test_op = _TEST_OP

    input_value = onnx.helper.make_tensor_value_info(
        name="input",
        elem_type=onnx.TensorProto.UINT8,
        shape=[16, 16],
    )

    # Scalar quantization parameters, shared by the DQ and Q nodes.
    scale = onnx.numpy_helper.from_array(
        np.array(0.045887697488069534, dtype=np.float32),
        name="scale",
    )
    zero_point = onnx.numpy_helper.from_array(
        np.zeros([], dtype=np.uint8),
        name="zero_point",
    )
    initializer = [scale, zero_point]

    dq = onnx.helper.make_node(
        op_type="DequantizeLinear",
        inputs=["input", "scale", "zero_point"],
        outputs=["nonflat_float"],
    )

    if test_op == MiddleOp.FLATTEN:
        # Flatten with axis=0 yields a single row: [16, 16] -> [1, 256].
        middle_op = onnx.helper.make_node(
            op_type="Flatten",
            inputs=["nonflat_float"],
            outputs=["flat_float"],
            axis=0,
        )
    elif test_op == MiddleOp.RESHAPE:
        # Reshape needs its target shape supplied as an extra int64 initializer.
        shape = onnx.numpy_helper.from_array(
            np.array([1, 256], dtype=np.int64),
            name="shape",
        )
        initializer.append(shape)
        middle_op = onnx.helper.make_node(
            op_type="Reshape",
            inputs=["nonflat_float", "shape"],
            outputs=["flat_float"],
        )
    else:
        raise ValueError("Unknown _TEST_OP")

    q = onnx.helper.make_node(
        op_type="QuantizeLinear",
        inputs=["flat_float", "scale", "zero_point"],
        outputs=["output"],
    )
    output_value = onnx.helper.make_tensor_value_info(
        name="output",
        elem_type=onnx.TensorProto.UINT8,
        shape=[1, 256],
    )

    graph = onnx.helper.make_graph(
        nodes=[dq, middle_op, q],
        name="test_graph",
        inputs=[input_value],
        outputs=[output_value],
        initializer=initializer,
    )
    model = onnx.helper.make_model(graph)
    onnx.checker.check_model(model)
    return model
def _make_sess(model):
    """Create an ONNX Runtime inference session for ``model``.

    The session options are configured from the module-level settings:
    optimization level (``_ORT_OPTIM_LEVEL``), optimized-model output path
    (``_OPTIM_MODEL_PATH``) and profiling (``_ORT_ENABLE_PROFILING``).

    Args:
        model: The ONNX model to run.  It is serialized and handed to the
            session directly, so the session always matches the object that
            was passed in.  If ``None``, the model file at
            ``_ORIG_MODEL_PATH`` is loaded instead.

    Returns:
        The configured ``onnxruntime.InferenceSession``.
    """
    sess_opts = onnxruntime.SessionOptions()
    sess_opts.graph_optimization_level = _ORT_OPTIM_LEVEL
    sess_opts.optimized_model_filepath = _OPTIM_MODEL_PATH
    sess_opts.enable_profiling = _ORT_ENABLE_PROFILING

    if model is None:
        # No in-memory model supplied: fall back to the saved model file.
        model_source = _ORIG_MODEL_PATH
    else:
        # InferenceSession accepts either a file path or serialized model bytes.
        model_source = model.SerializeToString()
    return onnxruntime.InferenceSession(model_source, sess_opts)
# Build the test model, write it to disk, and open a session on it.
model = _make_test_model()
onnx.save(model, _ORIG_MODEL_PATH)
sess = _make_sess(model)

# Feed the values 0..255 as a 16x16 uint8 tensor.
in_array = np.arange(256, dtype=np.uint8).reshape(16, 16)
in_dict = {"input": in_array}
outputs = sess.run(["output"], in_dict)
out = outputs[0]

# The script expects the output to be the input values laid out as one row.
assert out.dtype == np.uint8
assert out.shape == (1, 256)
expected_out = np.arange(256, dtype=np.uint8)[np.newaxis, :]
np.testing.assert_array_equal(out, expected_out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment