import torch
import numpy as np
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)
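
# A minimal sketch of learned quantization in the style of LQ-Nets: each value
# is approximated as the inner product of an nbit binary code and a learned
# basis vector. Optimization alternates between two steps:
#   1. encode: with the basis fixed, snap each value to its nearest level
#      (a_encode / w_encode below);
#   2. regress: with the codes fixed, refit the basis by least squares
#      (a_basis_regression / w_basis_regression below).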
def a_encode(x, nbit, basis):
    """Quantize activations x onto 2**nbit non-negative levels spanned by `basis`."""
    num_levels = 2 ** nbit
    # initialize level multiplier: row i holds the nbit binary code of level i
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            level_multiplier_i[j] = float(level_number % 2)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # initialize threshold multiplier: each threshold is the midpoint of two adjacent levels
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)

    level_codes = torch.tensor(init_level_multiplier)
    basis = basis.view(nbit, 1)
    # sort the level values in ascending order, remembering which code produced each one
    level_values = torch.mm(level_codes, basis)
    level_values, level_indices = torch.topk(torch.transpose(level_values, 1, 0), k=num_levels)
    level_values = torch.flip(level_values, dims=(-1,))
    level_indices = torch.flip(level_indices, dims=(-1,))
    level_values = torch.transpose(level_values, 1, 0)
    level_indices = torch.transpose(level_indices, 1, 0)

    thrs_multiplier = torch.tensor(init_thrs_multiplier)
    thrs = torch.mm(thrs_multiplier, level_values)

    # assign each element to the highest level whose threshold it exceeds
    y = torch.zeros_like(x)
    zero_dims = [x.numel(), nbit]
    bits_y = torch.zeros(zero_dims)
    zero_y = torch.zeros_like(x)
    zero_bits_y = torch.zeros(zero_dims)
    for i in range(num_levels - 1):
        g = x > thrs[i]
        y = torch.where(g, zero_y + level_values[i + 1], y)
        bits_y = torch.where(g.view(-1, 1), zero_bits_y + level_codes[level_indices[i + 1]], bits_y)
    return y, bits_y
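
# With an all-positive basis the activation levels are non-negative and the
# returned code decodes back to the quantized value, row for row:
#   quantized.view(-1) == bits_y @ basis.
# This identity is what the regression step below relies on.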
def a_basis_regression(x, bits_y, nbit):
    """Refit the basis by least squares: minimize ||bits_y @ basis - x||^2."""
    BT = bits_y.t()
    # calculate BTxB = B^T B, the nbit x nbit Gram matrix of the bit columns
    BTxB = []
    for i in range(nbit):
        for j in range(nbit):
            BTxBij = BT[i] * BT[j]
            BTxBij = torch.sum(BTxBij)
            BTxB.append(BTxBij)
    BTxB = torch.stack(BTxB).view(nbit, nbit)
    BTxB_inv = torch.inverse(BTxB)
    # calculate BTxX = B^T x
    BTxX = []
    for i in range(nbit):
        BTxXi0 = BT[i] * x.view(-1)
        BTxXi0 = torch.sum(BTxXi0)
        BTxX.append(BTxXi0)
    BTxX = torch.stack(BTxX).view(nbit, 1)
    # closed-form solution of the normal equations
    new_basis = torch.mm(BTxB_inv, BTxX)
    return new_basis
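
# For reference, the closed form computed above is the ordinary least-squares
# solution of  min_v ||B v - x||^2,  i.e.  v* = (B^T B)^{-1} B^T x,
# where B is the (num_elements x nbit) bit matrix and v the basis vector.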
def w_encode(x, nbit, basis):
    """Quantize weights x channel-wise onto 2**nbit levels; codes are in {-1, +1}."""
    num_levels = 2 ** nbit
    out_channels = x.size(1)
    # initialize level multiplier: row i holds the {-1, +1} code of level i
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            binary_code = level_number % 2
            if binary_code == 0:
                binary_code = -1
            level_multiplier_i[j] = float(binary_code)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # initialize threshold multiplier: each threshold is the midpoint of two adjacent levels
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)

    level_codes = torch.tensor(init_level_multiplier)
    thrs_multiplier = torch.tensor(init_thrs_multiplier)
    # sort the per-channel level values in ascending order
    level_values = torch.mm(level_codes, basis)
    level_values, level_indices = torch.topk(torch.transpose(level_values, 1, 0), k=num_levels)
    level_values = torch.flip(level_values, dims=(-1,))
    level_indices = torch.flip(level_indices, dims=(-1,))
    level_values = torch.transpose(level_values, 1, 0)
    level_indices = torch.transpose(level_indices, 1, 0)
    # calculate thresholds
    thrs = torch.mm(thrs_multiplier, level_values)
    # gather the level codes per channel, in sorted-level order
    level_codes_channelwise_dims = [num_levels * out_channels, nbit]
    level_codes_channelwise = torch.zeros(level_codes_channelwise_dims)
    for i in range(num_levels):
        eq = torch.eq(level_indices, i)
        level_codes_channelwise = torch.where(eq.view(-1, 1), level_codes_channelwise + level_codes[i], level_codes_channelwise)
    level_codes_channelwise = level_codes_channelwise.view(num_levels, out_channels, nbit)
    # start everything at the lowest level, then move elements up threshold by threshold
    y = torch.zeros_like(x) + level_values[0].view(-1, 1, 1)
    # flatten x with the channel axis last so the bit codes stay channel-aligned
    reshape_x = x.permute(0, 2, 3, 1).reshape(-1, out_channels)
    zero_dims = [reshape_x.size(0) * out_channels, nbit]
    bits_y = torch.full(zero_dims, -1.)
    zero_y = torch.zeros_like(x)
    zero_bits_y = torch.zeros(zero_dims)
    zero_bits_y = zero_bits_y.view(-1, out_channels, nbit)
    for i in range(num_levels - 1):
        g = torch.gt(x, thrs[i].view(-1, 1, 1))
        y = torch.where(g, zero_y + level_values[i + 1].view(-1, 1, 1), y)
        g_flat = g.permute(0, 2, 3, 1).reshape(-1, 1)
        bits_y = torch.where(g_flat, (zero_bits_y + level_codes_channelwise[i + 1]).view(-1, nbit), bits_y)
    bits_y = bits_y.view(-1, out_channels, nbit)
    return y, bits_y
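
# Unlike a_encode, the weight codes live in {-1, +1} and every channel carries
# its own basis column, so levels, thresholds, and codes are all maintained per
# channel. With an all-positive basis, decoding is channel-wise:
#   quantized[m, c] == sum_b bits_y[m, c, b] * basis[b, c].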
def w_basis_regression(x, bits_y, nbit, basis):
    """Refit the per-channel basis by regularized least squares."""
    delta = 0.0001  # small ridge term keeping B^T B invertible
    out_channels = x.size(1)
    # flatten x with the channel axis last, matching the layout of bits_y
    reshape_x = x.permute(0, 2, 3, 1).reshape(-1, out_channels)
    sum_multiplier = torch.ones(1, reshape_x.size(0))  # row of ones: sums over samples via matmul
    sum_multiplier_basis = torch.ones(1, nbit)         # row of ones: sums over bits via matmul
    BT = bits_y.permute(2, 0, 1)  # [samples, out_channels, nbit] -> [nbit, samples, out_channels]
    # calculate BTxB = B^T B per channel
    BTxB = []
    for i in range(nbit):
        for j in range(nbit):
            BTxBij = BT[i] * BT[j]
            BTxBij = torch.mm(sum_multiplier, BTxBij)
            if i == j:
                mat_one = torch.ones(1, out_channels)
                BTxBij = BTxBij + (delta * mat_one)  # + delta on the diagonal
            BTxB.append(BTxBij)
    BTxB = torch.stack(BTxB).view(nbit, nbit, out_channels)
    # calculate the inverse of BTxB for each channel
    if nbit > 2:
        BTxB_transpose = BTxB.permute(2, 0, 1)
        BTxB_inv = torch.inverse(BTxB_transpose)
        BTxB_inv = BTxB_inv.permute(1, 2, 0)
    elif nbit == 2:
        # closed-form 2x2 inverse
        det = BTxB[0][0] * BTxB[1][1] - BTxB[0][1] * BTxB[1][0]
        inv = []
        inv.append(BTxB[1][1] / det)
        inv.append(-BTxB[0][1] / det)
        inv.append(-BTxB[1][0] / det)
        inv.append(BTxB[0][0] / det)
        BTxB_inv = torch.stack(inv).view(nbit, nbit, out_channels)
    elif nbit == 1:
        BTxB_inv = torch.reciprocal(BTxB)
    # calculate BTxX = B^T x per channel
    BTxX = []
    for i in range(nbit):
        BTxXi0 = BT[i] * reshape_x
        BTxXi0 = torch.mm(sum_multiplier, BTxXi0)
        BTxX.append(BTxXi0)
    BTxX = torch.stack(BTxX).view(nbit, out_channels)
    BTxX = BTxX + (delta * basis)  # regularize toward the previous basis
    # calculate the new basis, channel by channel
    new_basis = []
    for i in range(nbit):
        new_basis_i = BTxB_inv[i] * BTxX
        new_basis_i = torch.mm(sum_multiplier_basis, new_basis_i)
        new_basis.append(new_basis_i)
    new_basis = torch.stack(new_basis).view(nbit, out_channels)
    return new_basis
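
# For reference, this is a per-channel ridge-regularized least squares:
#   v_c = (B_c^T B_c + delta*I)^{-1} (B_c^T x_c + delta*v_c_prev).
# The delta terms keep B_c^T B_c invertible (e.g. when a bit column is
# constant) and gently pull the solution toward the previous basis.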
print ("*************** Approximation activations **************** ")
nbit = 10
NORM_PPF_0_75 = 0.6745
MOVING_AVERAGES_FACTOR = 0.4
loss = torch.nn.MSELoss()
basis = torch.tensor([(NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** i) for i in range(nbit)])
x = torch.randn(3,3,3) * 0.5 + 1
for _ in range(10):
quantized_x, bits_x = a_encode(x, nbit, basis)
print ("Loss: ", loss(quantized_x, x))
new_basis = a_basis_regression(x, bits_x, nbit).squeeze(-1)
basis = basis * MOVING_AVERAGES_FACTOR + new_basis * (1 - MOVING_AVERAGES_FACTOR)
print (x[0][-1])
print ("~~~~~~~~~~~~~~")
print (quantized_x[0][-1])
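
# Sanity check (a sketch, assuming the converged basis stays all positive):
# re-encode with the final basis and confirm the bit codes decode back to the
# quantized values.
check_q, check_bits = a_encode(x, nbit, basis)
recon = torch.mm(check_bits, basis.view(nbit, 1)).view_as(x)
print("max |quantized - decoded bits|:", (check_q - recon).abs().max().item())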
print ("*************** Approximation weights **************** ")
w = torch.randn(3,18,7,7) * 0.1
n = 7 * 7 * 18
base = NORM_PPF_0_75 * ((2. / n) ** 0.5) / (2 ** (nbit - 1))
init_basis = []
for j in range(nbit):
init_basis.append([(2 ** j) * base for i in range(w.size(1))])
basis = torch.tensor(init_basis)
for _ in range(10):
quantized_w, bits_w = w_encode(w, nbit, basis)
print ("Loss: ", loss(quantized_w, w))
new_basis = w_basis_regression(w, bits_w, nbit)
basis = basis * MOVING_AVERAGES_FACTOR + new_basis * (1 - MOVING_AVERAGES_FACTOR)
print (w[0][-1])
print ("~~~~~~~~~~~~~~")
print (quantized_w[0][-1])
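
# Channel-wise sanity check for the weights (same positive-basis assumption):
# decode the bit codes with the converged basis and compare against the
# quantized weights, flattened with the channel axis last.
check_q, check_bits = w_encode(w, nbit, basis)
recon = torch.einsum('mcb,bc->mc', check_bits, basis)
flat_q = check_q.permute(0, 2, 3, 1).reshape(-1, w.size(1))
print("max |quantized - decoded bits|:", (flat_q - recon).abs().max().item())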