Skip to content

Instantly share code, notes, and snippets.

@phuocphn
Last active June 3, 2020 04:32
Show Gist options
  • Save phuocphn/393a5420de8d3e274bd09b38c77cb699 to your computer and use it in GitHub Desktop.
Save phuocphn/393a5420de8d3e274bd09b38c77cb699 to your computer and use it in GitHub Desktop.
Representing an integer q by a K-bit binary encoding as the inner product between a basis vector and the binary coding vector
import torch
import numpy as np
# --- Reproducibility -------------------------------------------------------
# Deterministic cuDNN kernels plus fixed seeds for both RNGs, so that the random
# input drawn later in the script is the same on every run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
np.random.seed(0)

# --- Quantizer configuration -----------------------------------------------
nbit = 3                  # bits per code word (K)
num_levels = 2 ** nbit    # number of distinct codes / quantization levels
# 0.6745 is approximately the 0.75 quantile of the standard normal distribution,
# as the constant's name indicates.
NORM_PPF_0_75 = 0.6745

# Initial basis: a uniform step scaled by powers of two,
#   basis[b] = step * 2**b,  step = 2 * NORM_PPF_0_75 / (2**nbit - 1),
# so the 2**nbit binary codes map onto evenly spaced levels in [0, 2 * NORM_PPF_0_75].
basis = torch.tensor([
    (NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** bit)
    for bit in range(nbit)
])
# Alternative hand-picked basis (4 entries, so it would require nbit = 4):
# basis = torch.tensor([0.5, 0.125, 0.25, 0.45])
print("Basis: ", basis)
class bcolors:
    """ANSI terminal escape sequences used to colour console output.

    Wrap text as ``bcolors.WARNING + text + bcolors.ENDC`` so the colour is
    reset after the highlighted text.
    """

    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
# Level multiplier: row k is the binary code of level index k, least
# significant bit first, stored as floats (e.g. k = 6 -> [0., 1., 1.] for nbit = 3).
init_level_multiplier = [
    [float((level >> bit) & 1) for bit in range(nbit)]
    for level in range(num_levels)
]

# Threshold multiplier: row t puts weight 0.5 on entries t and t + 1. It is
# later multiplied with the ascending level values, so each row yields the
# midpoint (decision threshold) between two neighbouring levels.
init_thrs_multiplier = [
    [0.5 if col in (t, t + 1) else 0. for col in range(num_levels)]
    for t in range(num_levels - 1)
]

print("********* init_level_multiplier ********")
print(init_level_multiplier)
print("Number of levels: ", len(init_level_multiplier))
print("********* init_thrs_multiplier ********")
print(init_thrs_multiplier)
print("Number of thrs: ", len(init_thrs_multiplier))
print(bcolors.WARNING + "Calculate levels and sort " + bcolors.ENDC)
print("***********************")
# One code word per row: (num_levels x nbit) float tensor.
level_codes = torch.tensor(init_level_multiplier)
print("Level codes: \n", level_codes)
print("Level codes: (shape): ", level_codes.size())
print("Basis (shape): ", basis.size())
# Reshape the basis into a column vector so that every level value is the
# inner product <code, basis>: (num_levels x nbit) @ (nbit x 1) -> (num_levels x 1).
basis = basis.view(nbit, 1)
level_values = torch.mm(level_codes, basis)
print("Level values: \n", level_values)
print("*******************")
print("Level values (original): \n", level_values)
print("Level values shape (original): ", level_values.size())
print(bcolors.ENDC + "\n")
print("*******************")
# Sort the level values ascending while remembering which code word (row of
# level_codes) produced each one. topk over all num_levels entries of the
# (1 x num_levels) row returns them largest-first, so values and indices are
# afterwards reversed along the last dimension.
level_values, level_indices = level_values.t().topk(num_levels)
print("Ordered Level values: \n")
print(level_values)
print("Ordered Level indices: \n")
print(level_indices)
print("*******************")
level_values = level_values.flip(-1)
level_indices = level_indices.flip(-1)
print("Inversed Level values: \n", level_values)
print("Inversed Level indices (compare to level codes): \n", level_indices)
print("*******************")
# Back to column layout: both become (num_levels x 1).
level_values, level_indices = level_values.t(), level_indices.t()
print("Transposed level values: \n", level_values)
print("Transposed level indices: \n", level_indices)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
print(bcolors.WARNING + "Calculate threshold " + bcolors.ENDC)
print("***********************")
# Decision thresholds = midpoints of consecutive ascending levels:
# (num_levels - 1 x num_levels) @ (num_levels x 1) -> (num_levels - 1 x 1).
thrs_multiplier = torch.tensor(init_thrs_multiplier)
thrs = thrs_multiplier @ level_values
print("thrs: \n", thrs)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
# Random test input, elementwise N(1, 0.5**2), drawn after the global seeding.
x = torch.randn(13, 13, 13) * 0.5 + 1
# One code word (nbit bits) per flattened element of x.
zero_dims = [x.numel(), nbit]
# All-zero tensors shaped like the outputs; adding a single level value / code
# word to them broadcasts it to the full output shape for torch.where.
zero_y = torch.zeros_like(x)
zero_bits_y = torch.zeros(zero_dims)
# Start every element at the lowest level and its code word. Elements below the
# first threshold are never overwritten by the quantization loop, and the
# lowest level equals 0 only while every basis entry is non-negative.
y = zero_y + level_values[0]
bits_y = zero_bits_y + level_codes[level_indices[0]]
print("x: ", x)
print("y: ", y)
print("zero_dims:", zero_dims)
print("bits_y: \n", bits_y)
# Nearest-level quantization: the thresholds are ascending midpoints, so after
# the loop each element holds the highest level whose lower threshold it
# exceeds, and the matching row of bits_y holds that level's code word.
for i in range(num_levels - 1):
    g = x > thrs[i]  # mask of elements above the i-th threshold, computed once
    y = torch.where(g, zero_y + level_values[i + 1], y)
    bits_y = torch.where(
        g.view(-1, 1),
        zero_bits_y + level_codes[level_indices[i + 1]],
        bits_y,
    )
    print("Loop: ", i, " ", g, " ", y, " \n", bits_y)
print(bcolors.WARNING + "************* " + bcolors.ENDC)
print(bcolors.FAIL + "Restore :" + bcolors.ENDC)
print("Original value: ", x)
print("basis.shape: ", basis.shape)
print("bits_y.shape: ", bits_y.shape)
# Rebuild every element from its code word: (1 x nbit) @ (nbit x N) -> (1 x N).
restored = basis.view(1, -1) @ bits_y.T
print("Restore value: ", restored)
print("Quantized y: ", y)
loss = torch.nn.MSELoss()
print("quantized_loss:", loss(x, y))
print("bits_y: \n", bits_y)
# Same reconstruction laid out like x, for direct comparison with y.
print(restored.view_as(x))
print("****************************************************")
# Least-squares refit of the basis: with B the (N x nbit) matrix of code words
# (bits_y) and X the flattened input, solve  min_v ||B v - X||^2, i.e.
# v = (B^T B)^-1 B^T X.
BT = bits_y.T
print("*********** calculate BTxB ***************")
# Gram matrix of the bit planes, (nbit x nbit), as one matmul instead of
# nbit**2 Python-level reductions.
BTxB = torch.mm(BT, bits_y)
# Pseudo-inverse rather than torch.inverse. If some bit is never set (an
# all-zero column of bits_y) or two bit columns coincide, B^T B is singular
# and torch.inverse raises. pinverse still returns the minimum-norm
# least-squares solution, and it equals the inverse when B^T B is invertible.
BTxB_inv = torch.pinverse(BTxB)
print("BTxB_inv: \n")
print(BTxB_inv)
print("*********** calculate BTxX ***************")
# Correlation of each bit plane with the input, (nbit x 1).
BTxX = torch.mm(BT, x.view(-1, 1))
print("BTxX: \n", BTxX)
# New basis as a column vector (nbit x 1), comparable with the reshaped basis.
new_basis = torch.mm(BTxB_inv, BTxX)
print("New basis: ", new_basis)
print("Old basis: ", basis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment