@FL33TW00D · Created May 6, 2023 09:40

import argparse
import onnx
import onnx_graphsurgeon as gs
import concurrent.futures
import numpy as np
from numba import jit
import os
EXPORT_DIR = "."


@jit(nopython=True)
def quantize(matrix):
    """
    Quantize a matrix of float32 values to sint8, packing four values per uint32.

    Parameters
    ----------
    matrix : numpy.ndarray
        The (M, N) float32 matrix to quantize. N must be divisible by 4.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The (M, N // 4) uint32 quantized matrix and the scalar absmax.
    """
    M = matrix.shape[0]
    N = matrix.shape[1]
    block_size = 4
    # e.g. [768, 768] -> [768, 192]
    quantized_matrix = np.zeros((M, N // block_size), dtype=np.uint32)
    absmax = np.max(np.abs(matrix))
    # Quantize each value to sint8 and pack 4 of them, lowest byte first, into one uint32
    for i in range(M):
        for j in range(0, N, block_size):
            packed_value = (round(matrix[i, j] / absmax * 127) & 0xFF) | \
                ((round(matrix[i, j + 1] / absmax * 127) & 0xFF) << 8) | \
                ((round(matrix[i, j + 2] / absmax * 127) & 0xFF) << 16) | \
                ((round(matrix[i, j + 3] / absmax * 127) & 0xFF) << 24)
            quantized_matrix[i, j // block_size] = packed_value
    return (quantized_matrix, np.array(absmax, dtype=np.float32))
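
# A quick sanity check of the packing above (illustrative only; the example
# values are not from the original gist). With absmax == 2.0, 2.0 maps to 127
# (0x7F), -2.0 to -127 (0x81 as a byte), 0.5 to round(31.75) == 32 (0x20), and
# 0.0 to 0, so the packed word, lowest byte first, is 0x0020817F:
# >>> q, amax = quantize(np.array([[2.0, -2.0, 0.5, 0.0]], dtype=np.float32))
# >>> hex(int(q[0, 0]))
# '0x20817f'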


def quantize_node(node):
    if node.op == "MatMul" or node.op == "Gemm":
        # Find the constant that is providing the weights
        for provider in node.inputs:
            if isinstance(provider, gs.Constant):
                print("Quantizing: {}".format(provider.name))
                (quantized, absmax) = quantize(provider.values)
                # Overwriting the values also switches the constant's dtype to uint32
                provider.values = quantized
                absmax_const = gs.Constant("{}_absmax".format(provider.name), absmax)
                node.inputs.append(absmax_const)
                # Change the MatMul to a QMatMul (Gemm to QGemm)
                if node.op == "MatMul":
                    node.op = "QMatMul"
                elif node.op == "Gemm":
                    node.op = "QGemm"
                break
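

# Sketch of the inverse transform (an addition for clarity, not part of the
# original gist and not used by this script): a QMatMul / QGemm kernel
# consuming the packed weights would unpack each uint32 into four sint8 lanes
# and rescale by absmax / 127. Pure-Python reference, assuming the
# lowest-byte-first packing used in `quantize` above.
def dequantize(quantized_matrix, absmax):
    M, packed_N = quantized_matrix.shape
    out = np.zeros((M, packed_N * 4), dtype=np.float32)
    for i in range(M):
        for j in range(packed_N):
            packed = int(quantized_matrix[i, j])
            for b in range(4):
                byte = (packed >> (8 * b)) & 0xFF
                # Sign-extend the 8-bit lane back to a signed integer
                signed = byte - 256 if byte > 127 else byte
                out[i, j * 4 + b] = signed * float(absmax) / 127.0
    return out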


def main():
    """
    This script takes in an ONNX model and performs the following quantization
    scheme:
    1. Find every MatMul / Gemm in the model.
    2. Determine if one of its providers is a Constant.
    3. If so, do the following:
        3a. Determine the absmax value of the tensor.
        3b. Pack 4xF32 -> 1xU32 ((round(matrix[i] / absmax * 127) & 0xFF)).
        3c. Update the weights.
        3d. Change the dtype to U32.
        3e. Add a new input to the MatMul for the absmax value.
        3f. Change the MatMul to a QMatMul (Gemm to QGemm).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", help="filename of the ONNX model")
    parser.add_argument(
        "--export-dir",
        help="directory to export the modified models to",
        default=EXPORT_DIR,
    )
    args = parser.parse_args()
    model = gs.import_onnx(onnx.load(args.model))
    # Quantize the eligible nodes in parallel
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        futures = [executor.submit(quantize_node, node) for node in model.nodes]
        concurrent.futures.wait(futures)
    model.cleanup().toposort()
    gs_export = gs.export_onnx(model)
    if not os.path.exists(args.export_dir):
        os.makedirs(args.export_dir)
    out_path = os.path.join(args.export_dir, "quant_{}".format(os.path.basename(args.model)))
    onnx.save(gs_export, out_path)


if __name__ == "__main__":
    main()
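
# Example invocation (assuming this file is saved as quantize.py):
#   python quantize.py -m model.onnx --export-dir ./quantized
# This writes ./quantized/quant_model.onnx with every constant-weighted
# MatMul / Gemm rewritten to QMatMul / QGemm.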