-
-
Save jfrery/758c85c4e98998e3a6315168bdd83f44 to your computer and use it in GitHub Desktop.
Quantize an MNIST fully-connected model with Concrete ML's post-training affine quantization (6 bits), then compare quantized vs. floating-point accuracy.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Post-training affine quantization of an MNIST fully-connected model with
# Concrete ML, followed by a comparison of quantized vs. floating-point
# accuracy.
#
# NOTE(review): this script assumes `numpy_fc_model` (the float model),
# `mnist_test_data`, `mnist_test_target`, and `np` (NumPy) are defined
# earlier in the session — confirm against the surrounding notebook/gist.
# (Fix applied: removed the trailing " | |" table-scrape artifacts that made
# every line invalid Python, and normalized keyword-argument spacing.)

from concrete.ml.quantization import PostTrainingAffineQuantization
from concrete.ml.quantization import QuantizedArray

# Quantization settings: bit-width and signedness.
n_bits = 6          # 2**6 = 64 quantization levels
is_signed = False   # unsigned quantization

# Quantize the model weights/ops via post-training affine quantization.
pt_quant = PostTrainingAffineQuantization(
    n_bits=n_bits, numpy_model=numpy_fc_model, is_signed=is_signed
)

# Calibrate layers and activations on the test data.
quant_module = pt_quant.quantize_module(mnist_test_data)

# Quantize the input data with the same settings.
q_mnist_test_data = QuantizedArray(n_bits=n_bits, values=mnist_test_data, is_signed=is_signed)

# Mask of non-background pixels (MNIST has many identical black pixels;
# -0.42421296 is presumably the normalized black-pixel value — verify
# against the preprocessing step).
arg_diff_values = (mnist_test_data != -0.42421296)

# Compare dequantized input values against the real input values.
# Real input:
mnist_test_data[arg_diff_values][:16]
# Output: array([0.64495873, 1.9305104 , 1.5995764 , 1.4977505 , 0.33948106,
#                0.03400347, 2.401455  , 2.8087585 , 2.8087585 , 2.8087585 ,
#                2.8087585 , 2.6432915 , 2.0959773 , 2.0959773 , 2.0959773 ,
#                2.0959773 ],)

# Dequantized input (close to, but not exactly, the real values):
q_mnist_test_data.dequant()[arg_diff_values][:16]
# Output: array([0.66974755, 1.90620457, 1.59709032, 1.49405223, 0.3606333 ,
#                0.05151904, 2.42139499, 2.83354733, 2.83354733, 2.83354733,
#                2.83354733, 2.62747116, 2.11228074, 2.11228074, 2.11228074,
#                2.11228074])

# Inspect the quantized (integer) input values.
q_mnist_test_data.qvalues[arg_diff_values][:16]
# Output: array([21, 45, 39, 37, 15, 9, 55, 63, 63, 63, 63, 59, 49, 49, 49, 49])

# Inspect the quantized weights of the first layer.
next(iter(quant_module.quant_layers_dict.values()))[1].constant_inputs[1].qvalues[0][:16]
# Output: array([32, 49, 6, 9, 20, 40, 31, 57, 29, 40, 22, 26, 2, 11, 19, 33])

# Sanity check: all quantized inputs are integers in [0, 2**6) = [0, 64).
np.unique(q_mnist_test_data.qvalues[arg_diff_values])
# Output: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
#                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
#                34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
#                51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

# Accuracy of the quantized NumPy model (6 bits).
(quant_module.forward_and_dequant(q_mnist_test_data.qvalues).argmax(1) == mnist_test_target).mean()
# Output: 0.9726

# Accuracy drop (in percentage points) caused by 6-bit quantization,
# i.e. |float accuracy - quantized accuracy| * 100, rounded to 2 decimals.
np.round(100 * np.abs((numpy_fc_model(mnist_test_data).argmax(1) == mnist_test_target).mean() - (quant_module.forward_and_dequant(q_mnist_test_data.qvalues).argmax(1) == mnist_test_target).mean()), 2)
# Output: 0.07 (%)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment