-
-
Save FL33TW00D/d81562557279d887705985f7c6ae4481 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import onnx | |
import onnx_graphsurgeon as gs | |
import concurrent.futures | |
import numpy as np | |
from numba import jit | |
import os | |
# Default value of the --export-dir command-line option (current directory).
EXPORT_DIR = "."
def quantize(matrix):
    """
    Quantize a 2-D matrix of floats to sint8 values packed into uint32 words.

    Every value is scaled by the matrix-wide absolute maximum into the range
    [-127, 127], rounded to the nearest integer (ties to even), reduced to
    its two's-complement byte and packed four-per-word along the column
    axis. Column ``4*k`` ends up in the least significant byte of word ``k``
    and column ``4*k + 3`` in the most significant byte.

    Parameters
    ----------
    matrix : numpy.ndarray
        2-D array of shape (M, N). N must be a multiple of 4.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        The packed matrix of shape (M, N // 4) with dtype uint32, and the
        absolute maximum of ``matrix`` as a 0-d float32 array. An empty or
        all-zero matrix yields an all-zero packed matrix and an absolute
        maximum of 0.

    Raises
    ------
    ValueError
        If ``matrix`` is not 2-D, or N is not a multiple of 4.
    """
    block_size = 4
    matrix = np.asarray(matrix)

    if matrix.ndim != 2:
        raise ValueError(
            "quantize expects a 2-D matrix, got {} dimension(s)".format(matrix.ndim)
        )
    rows, cols = matrix.shape
    if cols % block_size != 0:
        # A ragged tail cannot be packed: the last word would need columns
        # that do not exist.
        raise ValueError(
            "number of columns ({}) must be a multiple of {}".format(cols, block_size)
        )

    # [768, 768] -> [768, 192]
    packed_shape = (rows, cols // block_size)

    if matrix.size == 0:
        # np.max is undefined on an empty array; there is nothing to scale.
        return (
            np.zeros(packed_shape, dtype=np.uint32),
            np.array(0.0, dtype=np.float32),
        )

    absmax = np.max(np.abs(matrix))
    if absmax == 0:
        # Every value is zero: skip the division and emit zero words.
        return (
            np.zeros(packed_shape, dtype=np.uint32),
            np.array(absmax, dtype=np.float32),
        )

    # Divide in the input precision, then scale in float64 so the product
    # is not rounded a second time before it is rounded to an integer.
    scaled = (matrix / absmax).astype(np.float64) * 127.0

    # Masking with 0xFF keeps the two's-complement byte of each signed value.
    lanes = (np.rint(scaled).astype(np.int64) & 0xFF).astype(np.uint32)
    lanes = lanes.reshape(rows, cols // block_size, block_size)

    # Pack four consecutive columns into one little-endian uint32 word.
    quantized_matrix = (
        lanes[:, :, 0]
        | (lanes[:, :, 1] << np.uint32(8))
        | (lanes[:, :, 2] << np.uint32(16))
        | (lanes[:, :, 3] << np.uint32(24))
    )
    return (quantized_matrix, np.array(absmax, dtype=np.float32))
def quantize_node(node):
    """
    Quantize the constant weights of a single MatMul / Gemm node in place.

    Nodes with any other op, and nodes without a ``gs.Constant`` input, are
    left untouched. Otherwise the first constant input has its values
    replaced by the packed result of ``quantize``, a new constant named
    ``<weights name>_absmax`` holding the scale is appended to the node's
    inputs, and the op is renamed to its quantized counterpart.
    """
    # Supported ops and the quantized op each one is rewritten to.
    quantized_ops = {"MatMul": "QMatMul", "Gemm": "QGemm"}
    if node.op not in quantized_ops:
        return

    # Only the first constant input is treated as the weight tensor.
    weights = next(
        (candidate for candidate in node.inputs if isinstance(candidate, gs.Constant)),
        None,
    )
    if weights is None:
        return

    print("Quantizing: {}".format(weights.name))
    packed, scale = quantize(weights.values)
    weights.values = packed

    # The scale travels with the node as an extra trailing input.
    node.inputs.append(gs.Constant("{}_absmax".format(weights.name), scale))
    node.op = quantized_ops[node.op]
def main():
    """
    Quantize the constant weights of every MatMul / Gemm in an ONNX model.

    This script takes in an ONNX model, and performs the following
    quantization scheme.
    1. Find every MatMul / Gemm in the model
    2. Determine if one of its providers is a Constant
    3. If so, do the following:
    3a. Determine absmax value of the tensor
    3b. Pack 4xF32 -> 1xU32 ((Math.round(matrix[i] / absmax * 127) & 0xFF) )
    3c. update weights
    3d. change dtype to U32
    3e. add a new input to the matmul for the absmax value
    3f. change the matmul to a QMatmul

    The result is saved as ``quant_<model file name>`` inside the directory
    given by ``--export-dir``, which is created when missing. Nodes whose
    quantization raises are reported on stdout and the run continues.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        required=True,
        help="filename of the ONNX model",
    )
    parser.add_argument(
        "--export-dir",
        help="directory to export the modified models to",
        default=EXPORT_DIR,
    )
    args = parser.parse_args()

    model = gs.import_onnx(onnx.load(args.model))

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        pending = {
            executor.submit(quantize_node, node): node for node in model.nodes
        }
        # Inspect every future: an exception raised inside a worker is
        # otherwise stored on the future and never seen.
        for future in concurrent.futures.as_completed(pending):
            error = future.exception()
            if error is not None:
                print(
                    "Failed to quantize {}: {!r}".format(pending[future].name, error)
                )

    model.cleanup().toposort()
    gs_export = gs.export_onnx(model)

    # Honour the command-line option rather than the module-level default.
    os.makedirs(args.export_dir, exist_ok=True)

    # Use only the file name so a model given as a path does not leak its
    # directories into the output name.
    output_name = "quant_{}".format(os.path.basename(args.model))
    onnx.save(gs_export, os.path.join(args.export_dir, output_name))
# Run the quantization only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment