# Gist by @phuocphn, created June 3, 2020
import torch
import numpy as np
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)
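# This script appears to implement the quantizer-optimization step of
# LQ-Nets-style learned quantization (an assumption based on the code
# structure; the gist itself does not say): a value x is approximated as a
# binary combination of nbit basis elements, x ~ bits @ basis, alternating
# between encoding (fix basis, pick the nearest level) and a least-squares
# refit of the basis (fix the bit codes, solve for the basis).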
def encode(x, nbit, basis):
    """Quantize x onto the levels spanned by `basis`.

    Each of the 2**nbit levels is a binary combination of the basis
    elements. Returns the quantized tensor and the per-element bit codes.
    """
    num_levels = 2 ** nbit
    # Enumerate all 2**nbit binary codes (LSB first), one row per level.
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            level_multiplier_i[j] = float(level_number % 2)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # Each threshold is the midpoint of two adjacent (sorted) level values.
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)
    level_codes = torch.tensor(init_level_multiplier)
    basis = basis.view(nbit, 1)
    # Level values are all binary combinations of the basis: codes @ basis.
    level_values = torch.mm(level_codes, basis)
    # Sort the levels ascending, remembering each level's original code index.
    level_values, level_indices = torch.topk(torch.transpose(level_values, 1, 0), k=num_levels)
    level_values = torch.flip(level_values, dims=(-1,))
    level_indices = torch.flip(level_indices, dims=(-1,))
    level_values = torch.transpose(level_values, 1, 0)
    level_indices = torch.transpose(level_indices, 1, 0)
    thrs_multiplier = torch.tensor(init_thrs_multiplier)
    thrs = torch.mm(thrs_multiplier, level_values)
    # Start every element at the lowest level (value 0, all-zero code) and
    # bump it up one level for each threshold it exceeds.
    y = torch.zeros_like(x)
    bits_y = torch.zeros(x.numel(), nbit)
    zero_y = torch.zeros_like(x)
    zero_bits_y = torch.zeros(x.numel(), nbit)
    for i in range(num_levels - 1):
        y = torch.where(x > thrs[i], zero_y + level_values[i + 1], y)
        bits_y = torch.where((x > thrs[i]).view(-1, 1),
                             zero_bits_y + level_codes[level_indices[i + 1]],
                             bits_y)
    return y, bits_y
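# Quick sanity check of encode (illustrative, not in the original gist): with
# a 2-bit basis [1, 2] the four representable levels are {0, 1, 2, 3} and the
# thresholds sit at the midpoints {0.5, 1.5, 2.5}, so these four inputs snap
# to the four levels in order.
_toy_y, _toy_bits = encode(torch.tensor([0.2, 0.9, 1.7, 2.8]), nbit=2,
                           basis=torch.tensor([1., 2.]))
assert torch.equal(_toy_y, torch.tensor([0., 1., 2., 3.]))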
def basis_linear_regression(x, bits_y, nbit):
    """Refit the basis by least squares: minimize ||bits_y @ v - x||^2.

    Solves the normal equations v = (B^T B)^{-1} B^T x with B = bits_y.
    (`x` is passed in explicitly here; the original relied on a global.)
    """
    BT = bits_y.t()
    BTxB = torch.mm(BT, bits_y)          # B^T B, shape (nbit, nbit)
    BTxB_inv = torch.inverse(BTxB)
    BTxX = torch.mm(BT, x.view(-1, 1))   # B^T x, shape (nbit, 1)
    new_basis = torch.mm(BTxB_inv, BTxX)
    return new_basis
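# Cross-check of the regression (illustrative; assumes torch.linalg.lstsq,
# available in PyTorch >= 1.9): the hand-rolled normal equations should agree
# with a direct least-squares solve on a small full-rank example.
_B = torch.tensor([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
_t = torch.tensor([0.1, 1.1, 1.9, 3.2])
assert torch.allclose(basis_linear_regression(_t, _B, 2),
                      torch.linalg.lstsq(_B, _t.view(-1, 1)).solution,
                      atol=1e-4)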
NORM_PPF_0_75 = 0.6745  # 75th percentile of the standard normal, Phi^{-1}(0.75)
MOVING_AVERAGES_FACTOR = 0.4  # weight kept on the old basis at each update
nbit = 8
num_levels = 2 ** nbit
# Initial basis: a geometric sequence 2^i, scaled so the levels roughly cover
# the central mass of a unit-variance input.
basis = torch.tensor([(NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** i) for i in range(nbit)])
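# Worked check of the initialization (illustrative, not in the original gist):
# the scale factor is chosen so the largest representable level (all bits set,
# i.e. the sum of all basis elements) equals 2 * NORM_PPF_0_75 ~= 1.349.
assert abs(basis.sum().item() - 2 * NORM_PPF_0_75) < 1e-5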
x = torch.randn(3, 3, 3) * 0.5 + 1  # synthetic input to quantize
loss = torch.nn.MSELoss()
# Alternating optimization: encode with the current basis, then refit the
# basis to the resulting bit codes by least squares, smoothed with a moving
# average.
for _ in range(10):
    quantized_x, bits_x = encode(x, nbit, basis)
    # print("Basis vector: ", basis)
    # print("Original value: ", x)
    # print("Quantized value: ", quantized_x)
    print("Loss: ", loss(quantized_x, x))
    new_basis = basis_linear_regression(x, bits_x, nbit).squeeze(-1)
    basis = basis * MOVING_AVERAGES_FACTOR + new_basis * (1 - MOVING_AVERAGES_FACTOR)
    print("New basis: ", basis)