(Testcases) Learned Step Size Quantization
@phuocphn · Created June 1, 2020

import torch as t
import numpy as np


class GradScale(t.nn.Module):
    """Pass `x` through unchanged in the forward pass; scale its gradient by 1/scale."""
    def forward(self, x, scale):
        y = x
        y_grad = x / scale
        # The value comes from y, the gradient from y_grad (straight-through trick).
        return (y - y_grad).detach() + y_grad
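
# --- Illustrative sketch (added; not in the original gist): GradScale passes
# the value through unchanged while the gradient is scaled by 1/scale.
# `_p` and `_out` are throwaway names introduced here.
_p = t.tensor([2.0], requires_grad=True)
_out = GradScale()(_p, 4.0)
_out.backward()
assert _out.item() == 2.0      # forward value unchanged
assert _p.grad.item() == 0.25  # gradient scaled by 1/scale = 1/4
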
class RoundPass(t.nn.Module):
    """Round in the forward pass; pass the gradient straight through (STE)."""
    def forward(self, x):
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad
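
# --- Illustrative sketch (added; not in the original gist): RoundPass rounds
# in the forward pass but acts as the identity in the backward pass (a
# straight-through estimator), so rounding does not block gradient flow.
# `_v` and `_r` are throwaway names introduced here.
_v = t.tensor([0.4, 1.6], requires_grad=True)
_r = RoundPass()(_v)
_r.sum().backward()
assert t.equal(_r.detach(), t.tensor([0., 2.]))  # rounded forward values
assert t.equal(_v.grad, t.ones(2))               # identity gradient
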
class Quantize(t.nn.Module):
    def __init__(self, is_activation, bit):
        super(Quantize, self).__init__()
        self.s = t.nn.Parameter(t.ones(1))  # learned step size
        self.s.data.fill_(0.125)
        if is_activation:
            # Unsigned activations are quantized to [0, 2^b - 1].
            self.thd_neg = 0
            self.thd_pos = 2 ** bit - 1
        else:
            # Signed weights are quantized to [-2^(b-1), 2^(b-1) - 1].
            self.thd_neg = -2 ** (bit - 1)
            self.thd_pos = 2 ** (bit - 1) - 1
        self.grad_scale = GradScale()
        self.round_pass = RoundPass()
        print("Possible range: [{} ~ {}]".format(self.thd_neg, self.thd_pos))

    def update_s(self, new_s):
        self.s.data.fill_(new_s)

    def forward(self, x):
        # Gradient-scale factor from the LSQ paper: the gradient of s is
        # effectively multiplied by 1 / sqrt(thd_pos * numel).
        s_grad_scale = (self.thd_pos * x.numel()) ** 0.5
        s_scale = self.grad_scale(self.s, s_grad_scale)
        # print(f"Scale before: {self.s.item()}, Scale after: {s_scale.item()}")
        # ~~> the value of self.s does not change; only its gradient is rescaled.

        # Rescale the original values onto the integer grid.
        x = x / s_scale
        # ************ forward/before *************
        # tensor([-1.0000, -0.9592, -0.9184, -0.8776, -0.8367, -0.7959, -0.7551, -0.7143,
        #         -0.6735, -0.6327, -0.5918, -0.5510, -0.5102, -0.4694, -0.4286, -0.3878,
        #         -0.3469, -0.3061, -0.2653, -0.2245, -0.1837, -0.1429, -0.1020, -0.0612,
        #         -0.0204,  0.0204,  0.0612,  0.1020,  0.1429,  0.1837,  0.2245,  0.2653,
        #          0.3061,  0.3469,  0.3878,  0.4286,  0.4694,  0.5102,  0.5510,  0.5918,
        #          0.6327,  0.6735,  0.7143,  0.7551,  0.7959,  0.8367,  0.8776,  0.9184,
        #          0.9592,  1.0000], dtype=torch.float64)
        # ************ forward/after *************
        # tensor([-8.0000, -7.6735, -7.3469, -7.0204, -6.6939, -6.3673, -6.0408, -5.7143,
        #         -5.3878, -5.0612, -4.7347, -4.4082, -4.0816, -3.7551, -3.4286, -3.1020,
        #         -2.7755, -2.4490, -2.1224, -1.7959, -1.4694, -1.1429, -0.8163, -0.4898,
        #         -0.1633,  0.1633,  0.4898,  0.8163,  1.1429,  1.4694,  1.7959,  2.1224,
        #          2.4490,  2.7755,  3.1020,  3.4286,  3.7551,  4.0816,  4.4082,  4.7347,
        #          5.0612,  5.3878,  5.7143,  6.0408,  6.3673,  6.6939,  7.0204,  7.3469,
        #          7.6735,  8.0000], dtype=torch.float64, grad_fn=<DivBackward0>)

        # Clamp values that fall outside the representable integer range.
        x = t.clamp(x, self.thd_neg, self.thd_pos)
        # Round with a straight-through gradient.
        x = self.round_pass(x)
        # Map back to the original scale.
        x = x * s_scale
        return x
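
# For reference (added; not in the original gist): stripped of the gradient
# machinery, Quantize.forward collapses to a single expression. The helper
# name `lsq_quantize_ref` is ours.
def lsq_quantize_ref(x, s, thd_neg, thd_pos):
    # x_q = round(clamp(x / s, thd_neg, thd_pos)) * s
    return t.clamp(x / s, thd_neg, thd_pos).round() * s
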
quant = Quantize(is_activation=False, bit=4)
quant.update_s(0.125)

x = t.tensor(np.linspace(-1.0, 1.0, num=50))
print("************ Original tensor ************* ")
print(x)
print("************ Quantized tensor ************* ")
print(quant(x))
print("Number of unique values: ", t.unique(quant(x)).numel())

# print("********************************")
# for v in np.linspace(-1.0, 1.0, num=50):
#     x = t.tensor(round(v, 2))
#     quantized_x = quant(x)
#     print("{} ~> {}".format(x, quantized_x))
# print("********************************")
print("\n\n\n")

x = t.tensor(np.linspace(-0.5, 0.5, num=50))
quant.update_s(0.05)
print("************ Original tensor (2) ************* ")
print(x)
print("************ Quantized tensor (2) ************* ")
print(quant(x))
print("Number of unique values (2): ", t.unique(quant(x)).numel())
print("\n\n\n")

x = t.tensor(np.linspace(-0.1, 0.1, num=50))
quant.update_s(0.005)
print("************ Original tensor (3) ************* ")
print(x)
print("************ Quantized tensor (3) ************* ")
print(quant(x))
print("Number of unique values (3): ", t.unique(quant(x)).numel())
# Results
'''
(py3-env) phuocphn@phuocphn-Precision-5820-Tower:~/bnn-testcase$ python lsq-quantize.py
Possible range: [-8 ~ 7]
************ Original tensor *************
tensor([-1.0000, -0.9592, -0.9184, -0.8776, -0.8367, -0.7959, -0.7551, -0.7143,
        -0.6735, -0.6327, -0.5918, -0.5510, -0.5102, -0.4694, -0.4286, -0.3878,
        -0.3469, -0.3061, -0.2653, -0.2245, -0.1837, -0.1429, -0.1020, -0.0612,
        -0.0204,  0.0204,  0.0612,  0.1020,  0.1429,  0.1837,  0.2245,  0.2653,
         0.3061,  0.3469,  0.3878,  0.4286,  0.4694,  0.5102,  0.5510,  0.5918,
         0.6327,  0.6735,  0.7143,  0.7551,  0.7959,  0.8367,  0.8776,  0.9184,
         0.9592,  1.0000], dtype=torch.float64)
************ Quantized tensor *************
tensor([-1.0000, -1.0000, -0.8750, -0.8750, -0.8750, -0.7500, -0.7500, -0.7500,
        -0.6250, -0.6250, -0.6250, -0.5000, -0.5000, -0.5000, -0.3750, -0.3750,
        -0.3750, -0.2500, -0.2500, -0.2500, -0.1250, -0.1250, -0.1250,  0.0000,
         0.0000,  0.0000,  0.0000,  0.1250,  0.1250,  0.1250,  0.2500,  0.2500,
         0.2500,  0.3750,  0.3750,  0.3750,  0.5000,  0.5000,  0.5000,  0.6250,
         0.6250,  0.6250,  0.7500,  0.7500,  0.7500,  0.8750,  0.8750,  0.8750,
         0.8750,  0.8750], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values: 16
************ Original tensor (2) *************
tensor([-0.5000, -0.4796, -0.4592, -0.4388, -0.4184, -0.3980, -0.3776, -0.3571,
        -0.3367, -0.3163, -0.2959, -0.2755, -0.2551, -0.2347, -0.2143, -0.1939,
        -0.1735, -0.1531, -0.1327, -0.1122, -0.0918, -0.0714, -0.0510, -0.0306,
        -0.0102,  0.0102,  0.0306,  0.0510,  0.0714,  0.0918,  0.1122,  0.1327,
         0.1531,  0.1735,  0.1939,  0.2143,  0.2347,  0.2551,  0.2755,  0.2959,
         0.3163,  0.3367,  0.3571,  0.3776,  0.3980,  0.4184,  0.4388,  0.4592,
         0.4796,  0.5000], dtype=torch.float64)
************ Quantized tensor (2) *************
tensor([-0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.4000, -0.3500,
        -0.3500, -0.3000, -0.3000, -0.3000, -0.2500, -0.2500, -0.2000, -0.2000,
        -0.1500, -0.1500, -0.1500, -0.1000, -0.1000, -0.0500, -0.0500, -0.0500,
         0.0000,  0.0000,  0.0500,  0.0500,  0.0500,  0.1000,  0.1000,  0.1500,
         0.1500,  0.1500,  0.2000,  0.2000,  0.2500,  0.2500,  0.3000,  0.3000,
         0.3000,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,  0.3500,
         0.3500,  0.3500], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values (2): 16
************ Original tensor (3) *************
tensor([-0.1000, -0.0959, -0.0918, -0.0878, -0.0837, -0.0796, -0.0755, -0.0714,
        -0.0673, -0.0633, -0.0592, -0.0551, -0.0510, -0.0469, -0.0429, -0.0388,
        -0.0347, -0.0306, -0.0265, -0.0224, -0.0184, -0.0143, -0.0102, -0.0061,
        -0.0020,  0.0020,  0.0061,  0.0102,  0.0143,  0.0184,  0.0224,  0.0265,
         0.0306,  0.0347,  0.0388,  0.0429,  0.0469,  0.0510,  0.0551,  0.0592,
         0.0633,  0.0673,  0.0714,  0.0755,  0.0796,  0.0837,  0.0878,  0.0918,
         0.0959,  0.1000], dtype=torch.float64)
************ Quantized tensor (3) *************
tensor([-0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0350, -0.0300, -0.0250, -0.0200, -0.0200, -0.0150, -0.0100, -0.0050,
         0.0000,  0.0000,  0.0050,  0.0100,  0.0150,  0.0200,  0.0200,  0.0250,
         0.0300,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,
         0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,  0.0350,
         0.0350,  0.0350], dtype=torch.float64, grad_fn=<MulBackward0>)
Number of unique values (3): 16
'''
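
# --- Note (added; not in the original gist): every test reports 16 unique
# values because a 4-bit signed quantizer has 2^4 = 16 integer levels (-8..7).
# With s = 0.125 the representable grid is {-8, ..., 7} * 0.125, i.e.
# -1.000, -0.875, ..., 0.875; inputs beyond 7 * s are clamped to the top
# level, which is why both 0.9592 and 1.0000 map to 0.8750 in the first test.
# `_levels` is a throwaway name introduced here.
_levels = t.arange(-8, 8, dtype=t.float64) * 0.125
assert _levels.numel() == 16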