Quantize MNIST model
# Quantization
# Specify the number of bits for the quantization and choose whether we want signed or unsigned quantization.
import numpy as np

n_bits = 6
is_signed = False
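# With n_bits = 6 and unsigned quantization, quantized values are integers in
# [0, 2**6 - 1] = [0, 63]; with signed quantization they would lie in [-32, 31].
# A small illustrative check of that range (not part of the original gist):
q_range = (0, 2**n_bits - 1) if not is_signed else (-(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1)
print(q_range)
# Output: (0, 63)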
# Post-training quantization from Concrete ML
from concrete.ml.quantization import PostTrainingAffineQuantization
# Quantized array from Concrete ML
from concrete.ml.quantization import QuantizedArray
# Quantize our model with post-training affine quantization
pt_quant = PostTrainingAffineQuantization(n_bits=n_bits, numpy_model=numpy_fc_model, is_signed=is_signed)
# Calibrate layers and activations
quant_module = pt_quant.quantize_module(mnist_test_data)
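# The returned quantized module keeps one quantized operation per layer in
# quant_layers_dict (the same attribute is used further down to inspect the weights).
# Listing its keys is a quick, illustrative way to see what was calibrated
# (this inspection step is not part of the original gist):
print(list(quant_module.quant_layers_dict.keys()))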
# Quantize input
q_mnist_test_data = QuantizedArray(n_bits=n_bits, values=mnist_test_data, is_signed=is_signed)
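# Under the hood, QuantizedArray applies uniform affine quantization. A minimal
# sketch of that scheme in plain numpy (illustrative only; the exact scale and
# rounding details inside Concrete ML may differ slightly):
scale = (mnist_test_data.max() - mnist_test_data.min()) / (2**n_bits - 1)
zero_point = np.round(-mnist_test_data.min() / scale)
q_manual = np.clip(np.round(mnist_test_data / scale) + zero_point, 0, 2**n_bits - 1)
dequant_manual = scale * (q_manual - zero_point)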
# Get the positions of non-background values (MNIST images are mostly black pixels, whose normalized value is -0.42421296)
arg_diff_values = (mnist_test_data != -0.42421296)
# Compare dequantized input values vs. the real input values
# Real input
mnist_test_data[arg_diff_values][:16]
# Output: array([0.64495873, 1.9305104 , 1.5995764 , 1.4977505 , 0.33948106,
# 0.03400347, 2.401455 , 2.8087585 , 2.8087585 , 2.8087585 ,
# 2.8087585 , 2.6432915 , 2.0959773 , 2.0959773 , 2.0959773 ,
# 2.0959773 ])
# Dequantized input
q_mnist_test_data.dequant()[arg_diff_values][:16]
# Output: array([0.66974755, 1.90620457, 1.59709032, 1.49405223, 0.3606333 ,
# 0.05151904, 2.42139499, 2.83354733, 2.83354733, 2.83354733,
# 2.83354733, 2.62747116, 2.11228074, 2.11228074, 2.11228074,
# 2.11228074])
# Check the quantized input values
q_mnist_test_data.qvalues[arg_diff_values][:16]
# Output: array([21, 45, 39, 37, 15, 9, 55, 63, 63, 63, 63, 59, 49, 49, 49, 49])
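# The dequantized values should differ from the originals by at most one quantization
# step. A rough sanity check in plain numpy (not part of the original gist; it assumes
# the quantization range was calibrated on the data's own min/max, as sketched above):
step = (mnist_test_data.max() - mnist_test_data.min()) / (2**n_bits - 1)
assert np.all(np.abs(q_mnist_test_data.dequant() - mnist_test_data) <= step)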
# Check the quantized weights for the first layer
next(iter(quant_module.quant_layers_dict.values()))[1].constant_inputs[1].qvalues[0][:16]
# Output: array([32, 49, 6, 9, 20, 40, 31, 57, 29, 40, 22, 26, 2, 11, 19, 33])
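# The first layer's weights are quantized with their own scale and zero point, but
# they should still fit within a 6-bit range. A quick width check (illustrative,
# not part of the original gist):
first_layer_weights = next(iter(quant_module.quant_layers_dict.values()))[1].constant_inputs[1]
assert first_layer_weights.qvalues.max() - first_layer_weights.qvalues.min() <= 2**n_bits - 1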
# Make sure all quantized values are integers taking at most 2**6 (64) distinct values
np.unique(q_mnist_test_data.qvalues[arg_diff_values])
# Output: array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
# 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
# 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
# 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
# Accuracy of the quantized numpy model (6 bits)
(quant_module.forward_and_dequant(q_mnist_test_data.qvalues).argmax(1) == mnist_test_target).mean()
# Output: 0.9726
# Compute the drop in accuracy (in percentage points) caused by quantizing the floating point model to 6 bits of precision.
float_accuracy = (numpy_fc_model(mnist_test_data).argmax(1) == mnist_test_target).mean()
quantized_accuracy = (quant_module.forward_and_dequant(q_mnist_test_data.qvalues).argmax(1) == mnist_test_target).mean()
np.round(100 * np.abs(float_accuracy - quantized_accuracy), 2)
# Output: 0.07 (i.e. a 0.07% drop in accuracy)