Created
June 1, 2020 03:37
-
-
Save phuocphn/5e03e4309ba75cfa927a494b00a51edf to your computer and use it in GitHub Desktop.
(Testcases) Learned Step Size Quantization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as t | |
import numpy as np | |
class GradScale(t.nn.Module):
    """Scale a tensor's gradient by ``1/scale`` while leaving its value intact.

    Straight-through trick: numerically the output equals ``x``, but only the
    ``x / scale`` term stays attached to the autograd graph, so any upstream
    gradient reaching ``x`` is multiplied by ``1/scale``.
    """

    def forward(self, x, scale):
        value = x
        grad_path = x / scale
        # detach() hides the (value - grad_path) correction from autograd.
        return (value - grad_path).detach() + grad_path
class RoundPass(t.nn.Module):
    """Round to the nearest integer with a straight-through (identity) gradient.

    The forward value is ``x.round()``; the backward pass treats the op as the
    identity, because true rounding has zero gradient almost everywhere.
    """

    def forward(self, x):
        rounded = x.round()
        # Value follows `rounded`; the gradient flows through plain `x`.
        return (rounded - x).detach() + x
class Quantize(t.nn.Module):
    """LSQ-style fake quantizer with a learnable step size ``s``.

    Maps the input onto at most ``2**bit`` evenly spaced levels (unsigned
    range for activations, signed range for weights) and immediately
    de-quantizes back, so the output stays floating point but only takes the
    quantized values. The step size is a trainable parameter whose gradient
    is scaled as in the LSQ paper.
    """

    def __init__(self, is_activation, bit):
        super(Quantize, self).__init__()
        # Learnable step size, initialised to 0.125.
        self.s = t.nn.Parameter(t.ones(1))
        self.s.data.fill_(0.125)
        if is_activation:
            # Unsigned activations quantize to [0, 2^b - 1].
            self.thd_neg = 0
            self.thd_pos = 2 ** bit - 1
        else:
            # Signed weights quantize to [-2^(b-1), 2^(b-1) - 1].
            self.thd_neg = - 2 ** (bit - 1)
            self.thd_pos = 2 ** (bit - 1) - 1
        self.grad_scale = GradScale()
        self.round_pass = RoundPass()
        print ("Possible range: [{} ~ {}]".format(self.thd_neg, self.thd_pos))

    def update_s(self, new_s):
        """Overwrite the step size in place (bypasses autograd)."""
        self.s.data.fill_(new_s)

    def forward(self, x):
        # LSQ gradient scaling: GradScale divides by g, so gradients w.r.t.
        # `s` are scaled by 1/sqrt(Q_p * numel); the value of `step` is
        # numerically identical to `s`.
        g = (self.thd_pos * x.numel()) ** 0.5
        step = self.grad_scale(self.s, g)
        # Rescale into integer-level space, clamp to the representable
        # range, round (straight-through), then map back to real values.
        levels = t.clamp(x / step, self.thd_neg, self.thd_pos)
        levels = self.round_pass(levels)
        return levels * step
def _show(quant, x, new_s, label=""):
    """Quantize ``x`` with step size ``new_s`` and print original/quantized
    tensors plus the number of distinct quantized values.

    ``label`` is spliced into the headers (e.g. " (2)") so repeated demo
    sections keep their original numbering.
    """
    quant.update_s(new_s)
    print ("************ Original tensor{} ************* ".format(label))
    print (x)
    print ("************ Quantized tensor{} ************* ".format(label))
    print (quant(x))
    print ("Number of unique values{}: ".format(label), t.unique(quant(x)).numel())

# Demo: 4-bit signed (weight-style) quantization over three input ranges,
# with the step size shrunk to roughly match each range.
quant = Quantize(is_activation=False, bit=4)

x = t.tensor(np.linspace(-1.0, 1.0, num=50))
_show(quant, x, 0.125)
print ("\n\n\n")

x = t.tensor(np.linspace(-0.5, 0.5, num=50))
_show(quant, x, 0.05, label=" (2)")
print ("\n\n\n")

x = t.tensor(np.linspace(-0.1, 0.1, num=50))
_show(quant, x, 0.005, label=" (3)")
# Results
'''
(py3-env) phuocphn@phuocphn-Precision-5820-Tower:~/bnn-testcase$ python lsq-quantize.py
Possible range: [-8 ~ 7]
************ Original tensor *************
tensor([-1.0000, -0.9592, -0.9184, -0.8776, -0.8367, -0.7959, -0.7551, -0.7143,
        -0.6735, -0.6327, -0.5918, -0.5510, -0.5102, -0.4694, -0.4286, -0.3878,
        -0.3469, -0.3061, -0.2653, -0.2245, -0.1837, -0.1429, -0.1020, -0.0612,
        -0.0204,  0.0204,  0.0612,  0.1020,  0.1429,  0.1837,  0.2245,  0.2653,
         0.3061,  0.3469,  0.3878,  0.4286,  0.4694,  0.5102,  0.5510,  0.5918,
         0.6327,  0.6735,  0.7143,  0.7551,  0.7959,  0.8367,  0.8776,  0.9184,
         0.9592,  1.0000], dtype=torch.float64)
************ Quantized tensor *************
tensor([-1.0000, -1.0000, -0.8750, -0.8750, -0.8750, -0.7500, -0.7500, -0.7500,
        -0.6250, -0.6250, -0.6250, -0.5000, -0.5000, -0.5000, -0.3750, -0.3750,
        -0.3750, -0.2500, -0.2500, -0.2500, -0.1250, -0.1250, -0.1250,  0.0000,
         0.0000,  0.0000,  0.0000,  0.1250,  0.1250,  0.1250,  0.2500,  0.2500,
         0.2500,  0.3750,  0.3750,  0.3750,  0.5000,  0.5000,  0.5000,  0.6250,
         0.6250,  0.6250,  0.7500,  0.7500,  0.7500,  0.8750,  0.8750,  0.8750,
         0.8750,  0.8750], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values:  16
************ Original tensor (2) *************
tensor([-0.5000, -0.4796, -0.4592, -0.4388, -0.4184, -0.3980, -0.3776, -0.3571,
        -0.3367, -0.3163, -0.2959, -0.2755, -0.2551, -0.2347, -0.2143, -0.1939,
        -0.1735, -0.1531, -0.1327, -0.1122, -0.0918, -0.0714, -0.0510, -0.0306,
        -0.0102,  0.0102,  0.0306,  0.0510,  0.0714,  0.0918,  0.1122,  0.1327,
         0.1531,  0.1735,  0.1939,  0.2143,  0.2347,  0.2551,  0.2755,  0.2959,
         0.3163,  0.3367,  0.3571,  0.3776,  0.3980,  0.4184,  0.4388,  0.4592,
         0.4796,  0.5000], dtype=torch.float64)
************ Quantized tensor (2) *************
tensor([-0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.3500,
        -0.3500, -0.3000, -0.3000, -0.3000, -0.2500, -0.2500, -0.2000, -0.2000,
        -0.1500, -0.1500, -0.1500, -0.1000, -0.1000, -0.0500, -0.0500, -0.0500,
         0.0000,  0.0000,  0.0500,  0.0500,  0.0500,  0.1000,  0.1000,  0.1500,
         0.1500,  0.1500,  0.2000,  0.2000,  0.2500,  0.2500,  0.3000,  0.3000,
         0.3000,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,
         0.3500,  0.3500], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values (2):  16
************ Original tensor (3) *************
tensor([-0.1000, -0.0959, -0.0918, -0.0878, -0.0837, -0.0796, -0.0755, -0.0714,
        -0.0673, -0.0633, -0.0592, -0.0551, -0.0510, -0.0469, -0.0429, -0.0388,
        -0.0347, -0.0306, -0.0265, -0.0224, -0.0184, -0.0143, -0.0102, -0.0061,
        -0.0020,  0.0020,  0.0061,  0.0102,  0.0143,  0.0184,  0.0224,  0.0265,
         0.0306,  0.0347,  0.0388,  0.0429,  0.0469,  0.0510,  0.0551,  0.0592,
         0.0633,  0.0673,  0.0714,  0.0755,  0.0796,  0.0837,  0.0878,  0.0918,
         0.0959,  0.1000], dtype=torch.float64)
************ Quantized tensor (3) *************
tensor([-0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0350, -0.0300, -0.0250, -0.0200, -0.0200, -0.0150, -0.0100, -0.0050,
         0.0000,  0.0000,  0.0050,  0.0100,  0.0150,  0.0200,  0.0200,  0.0250,
         0.0300,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,
         0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,
         0.0350,  0.0350], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values (3):  16
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment