Last active
June 3, 2020 04:32
-
-
Save phuocphn/393a5420de8d3e274bd09b38c77cb699 to your computer and use it in GitHub Desktop.
Representing an integer q with a K-bit binary encoding: q is the inner product $q = \mathbf{v}^\top \mathbf{b}$ between a basis vector $\mathbf{v} \in \mathbb{R}^K$ and a binary code vector $\mathbf{b} \in \{0, 1\}^K$.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import numpy as np

# Reproducibility: make cuDNN deterministic, disable its autotuner, and seed
# both the NumPy and PyTorch random number generators.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)

nbit = 3                    # K: number of bits per binary code
num_levels = 2 ** nbit      # number of representable levels (2 ** K)
NORM_PPF_0_75 = 0.6745      # standard-normal inverse CDF (PPF) at 0.75, per its name

# Initial basis: basis[i] = (2 * NORM_PPF_0_75 / (2 ** nbit - 1)) * 2 ** i.
# Since sum(2 ** i for i < nbit) == 2 ** nbit - 1, the largest level (all bits
# set) equals 2 * NORM_PPF_0_75.
basis = torch.tensor([(NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** i) for i in range(nbit)])
# Alternative hand-picked basis kept for experimentation:
# basis = torch.tensor([0.5, 0.125, 0.25, 0.45])
print("Basis: ", basis)
class bcolors:
    """ANSI SGR escape sequences used to colour/format terminal output.

    Wrap text as ``bcolors.WARNING + text + bcolors.ENDC`` so the terminal
    returns to its default style afterwards.
    """

    HEADER = '\033[95m'     # bright magenta foreground
    OKBLUE = '\033[94m'     # bright blue foreground
    OKGREEN = '\033[92m'    # bright green foreground
    WARNING = '\033[93m'    # bright yellow foreground
    FAIL = '\033[91m'       # bright red foreground
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'        # bold
    UNDERLINE = '\033[4m'   # underline
# initialize level multiplier
# Row i is the nbit-long binary code of level index i, least-significant bit
# first: entry j is bit j of i, i.e. the coefficient applied to basis[j]
# (and basis[j] is proportional to 2 ** j).
init_level_multiplier = [
    [float((i >> j) & 1) for j in range(nbit)]
    for i in range(num_levels)
]
# initialize threshold multiplier
# One row per gap between neighbouring levels (num_levels - 1 rows). Row i - 1
# puts 0.5 at columns i - 1 and i, so multiplying it with the ascending-sorted
# level values gives the midpoint between levels i - 1 and i, which is the
# decision threshold between them.
init_thrs_multiplier = []
for i in range(1, num_levels):
    thrs_multiplier_i = [0.] * num_levels
    thrs_multiplier_i[i - 1] = 0.5
    thrs_multiplier_i[i] = 0.5
    init_thrs_multiplier.append(thrs_multiplier_i)
# Dump both multiplier tables and their row counts for inspection.
print("********* init_level_multiplier ********")
print(init_level_multiplier)
print("Number of levels: ", len(init_level_multiplier))
print("********* init_thrs_multiplier ********")
print(init_thrs_multiplier)
print("Number of thrs: ", len(init_thrs_multiplier))
print(bcolors.WARNING + "Calculate levels and sort " + bcolors.ENDC)
print("***********************")
# One row per level index: its binary code as floats, shape (num_levels, nbit).
level_codes = torch.tensor(init_level_multiplier)
print("Level codes: \n", level_codes)
print("Level codes: (shape): ", level_codes.size())
print("Basis (shape): ", basis.size())
# Turn the basis into a column vector so every level value is the inner
# product of its code with the basis: (num_levels, nbit) @ (nbit, 1).
basis = basis.view(nbit, 1)
level_values = torch.mm(level_codes, basis)                   # (num_levels, 1)
print("Level values: \n", level_values)
print("*******************")
print("Level values (original): \n", level_values)
print("Level values shape (original): ", level_values.size())
print(bcolors.ENDC + "\n")
print("*******************")
# topk with k == num_levels sorts all levels in descending order and also
# returns, for each sorted position, the row of level_codes it came from.
level_values, level_indices = torch.topk(torch.transpose(level_values, 1, 0), k=num_levels)
print("Ordered Level values: \n")
print(level_values)
print("Ordered Level indices: \n")
print(level_indices)
print("*******************")
# Reverse to ascending order (smallest level first).
level_values = torch.flip(level_values, dims=(-1, ))
level_indices = torch.flip(level_indices, dims=(-1, ))
print("Inversed Level values: \n", level_values)
print("Inversed Level indices (compare to level codes): \n", level_indices)
print("*******************")
# Back to column layout: both become (num_levels, 1).
level_values = torch.transpose(level_values, 1, 0)
level_indices = torch.transpose(level_indices, 1, 0)
print("Transposed level values: \n", level_values)
print("Transposed level indices: \n", level_indices)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
print(bcolors.WARNING + "Calculate threshold " + bcolors.ENDC)
print("***********************")
# (num_levels - 1, num_levels) @ (num_levels, 1) -> (num_levels - 1, 1):
# thrs[i] is the midpoint between ascending-sorted levels i and i + 1.
thrs_multiplier = torch.tensor(init_thrs_multiplier)
thrs = torch.mm(thrs_multiplier, level_values)
print("thrs: \n", thrs)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
# Sample input to quantize: Gaussian values with mean 1 and std 0.5.
x = torch.randn(13, 13, 13) * 0.5 + 1

# One nbit-long code per flattened element of x.
zero_dims = [x.numel(), nbit]
# Every element starts at the lowest sorted level and that level's code.
# Starting from zeros is only equivalent when the lowest level is code 0 with
# value 0, which requires every basis entry to be non-negative.
y = torch.full_like(x, level_values[0].item())
bits_y = torch.zeros(zero_dims) + level_codes[level_indices[0]]
# Zero tensors used to broadcast a single level/code to full shape in torch.where.
zero_y = torch.zeros_like(x)
zero_bits_y = torch.zeros(zero_dims)
print("x: ", x)
print("y: ", y)
print("zero_dims:", zero_dims)
print("bits_y: \n", bits_y)
for i in range(num_levels - 1):
    # Thresholds are midpoints of ascending-sorted levels, so they increase
    # with i. Later iterations overwrite earlier ones, and each element ends
    # on the highest level whose lower threshold it exceeds.
    g = x > thrs[i]
    y = torch.where(g, zero_y + level_values[i + 1], y)
    bits_y = torch.where(g.view(-1, 1), zero_bits_y + level_codes[level_indices[i + 1]], bits_y)
    print("Loop: ", i, " ", g, " ", y, " \n", bits_y)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
print(bcolors.FAIL + "Restore :" + bcolors.ENDC)
print("Original value: ", x)
print("basis.shape: ", basis.shape)
print("bits_y.shape: ", bits_y.shape)
# Decode every code back to a value: (1, nbit) @ (nbit, x.numel()).
# Computed once and reused below instead of repeating the matmul.
restored = torch.mm(basis.view(1, -1), bits_y.T)
print("Restore value: ", restored)
print("Quantized y: ", y)
loss = torch.nn.MSELoss()
print("quantized_loss:", loss(x, y))
print("bits_y: \n", bits_y)
print(restored.view_as(x))
print("****************************************************")
# Refit the basis by least squares with the codes held fixed. With
# B = bits_y of shape (x.numel(), nbit), the normal equations give
#   new_basis = (B^T B)^-1 B^T x
BT = bits_y.T
# calculate BTxB
print("*********** calculate BTxB ***************")
BTxB = torch.mm(BT, bits_y)                     # (nbit, nbit) Gram matrix
# pinverse rather than inverse: B^T B is singular when some bit is never set
# (or two bits always agree), where torch.inverse errors out or yields
# non-finite values. For a nonsingular matrix pinverse equals the inverse.
BTxB_inv = torch.pinverse(BTxB)
print("BTxB_inv: \n")
print(BTxB_inv)
# calculate BTxX
print("*********** calculate BTxX ***************")
BTxX = torch.mm(BT, x.view(-1, 1))              # (nbit, 1)
print("BTxX: \n", BTxX)
# calculate new basis
new_basis = torch.mm(BTxB_inv, BTxX)            # (nbit, 1)
print("New basis: ", new_basis)
print("Old basis: ", basis)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment