Last active
June 4, 2020 03:32
-
-
Save phuocphn/ae0c7facb6be5937d16e24b4926f6a2e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import numpy as np
# Make runs reproducible: force deterministic cuDNN kernels, disable the
# autotuner (which can pick different algorithms per run), and fix both RNGs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)
def a_encode(x, nbit, basis):
    """Quantize activations `x` onto the 2**nbit non-negative levels
    generated by `basis`.

    Each level is the sum of a subset of the `nbit` basis values, i.e. a
    binary {0,1} code times `basis`.  Thresholds are the midpoints between
    adjacent sorted levels, so each x lands on its nearest level.

    Args:
        x: activation tensor (any shape).
        nbit: number of bits per value.
        basis: tensor with `nbit` elements (reshaped to (nbit, 1)).

    Returns:
        (y, bits_y):
        y      -- tensor shaped like `x`, entries replaced by their level.
        bits_y -- (x.numel(), nbit) binary code of the level per entry.
    """
    num_levels = 2 ** nbit
    # Binary {0,1} code of every level index, LSB first.
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for _ in range(nbit)]
        level_number = i
        for j in range(nbit):
            level_multiplier_i[j] = float(level_number % 2)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # Row i of the threshold matrix averages level i and level i+1.
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for _ in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)
    level_codes = torch.tensor(init_level_multiplier)
    basis = basis.view(nbit, 1)
    level_values = torch.mm(level_codes, basis)
    # Sort level values ascending, remembering which code produced each
    # (topk sorts descending, hence the flip).
    level_values, level_indices = torch.topk(
        torch.transpose(level_values, 1, 0), k=num_levels)
    level_values = torch.flip(level_values, dims=(-1,))
    level_indices = torch.flip(level_indices, dims=(-1,))
    level_values = torch.transpose(level_values, 1, 0)
    level_indices = torch.transpose(level_indices, 1, 0)
    thrs_multiplier = torch.tensor(init_thrs_multiplier)
    thrs = torch.mm(thrs_multiplier, level_values)
    # Start at level 0 and raise each entry past every threshold it exceeds.
    y = torch.zeros_like(x)
    zero_dims = [x.numel(), nbit]
    bits_y = torch.zeros(zero_dims)  # was torch.ones(...).fill_(0.0)
    zero_y = torch.zeros_like(x)
    zero_bits_y = torch.zeros(zero_dims)
    for i in range(num_levels - 1):
        # Compute the comparison mask once per level (the original
        # recomputed `x > thrs[i]` three times and left `g` unused).
        g = x > thrs[i]
        y = torch.where(g, zero_y + level_values[i + 1], y)
        bits_y = torch.where(g.view(-1, 1),
                             zero_bits_y + level_codes[level_indices[i + 1]],
                             bits_y)
    return y, bits_y
def a_basis_regression(x, bits_y, nbit):
    """Least-squares refit of the activation basis.

    Solves (B^T B)^-1 B^T x with B = `bits_y`, i.e. ordinary least squares
    of the flattened activations on their binary codes.

    Args:
        x: the original activation tensor.
        bits_y: (x.numel(), nbit) binary code matrix from `a_encode`.
        nbit: number of bits (columns of `bits_y`).

    Returns:
        (nbit, 1) tensor of refitted basis values.
    """
    B = bits_y.view(-1, nbit)
    BT = B.t()
    # Single matrix products replace the original O(nbit^2) Python loops
    # over elementwise multiply-and-sum -- same arithmetic, C speed.
    BTxB = torch.mm(BT, B)
    BTxX = torch.mm(BT, x.view(-1, 1))
    new_basis = torch.mm(torch.inverse(BTxB), BTxX)
    return new_basis
def w_encode(x, nbit, basis): | |
num_levels = 2 ** nbit | |
out_channels = x.size(1) | |
# initialize level multiplier | |
init_level_multiplier = [] # ~ [-1, 1] | |
for i in range(num_levels): | |
level_multiplier_i = [0. for j in range(nbit)] | |
level_number = i | |
for j in range(nbit): | |
binary_code = level_number % 2 | |
if binary_code == 0: | |
binary_code = -1 | |
level_multiplier_i[j] = float(binary_code) | |
level_number = level_number // 2 | |
init_level_multiplier.append(level_multiplier_i) | |
# initialize threshold multiplier | |
init_thrs_multiplier = [] | |
for i in range(1, num_levels): | |
thrs_multiplier_i = [0. for j in range(num_levels)] | |
thrs_multiplier_i[i - 1] = 0.5 | |
thrs_multiplier_i[i] = 0.5 | |
init_thrs_multiplier.append(thrs_multiplier_i) | |
level_codes = torch.tensor(init_level_multiplier) | |
thrs_multiplier = torch.tensor(init_thrs_multiplier) | |
level_codes = torch.tensor(init_level_multiplier) | |
level_values = torch.mm(level_codes, basis) | |
level_values, level_indices = torch.topk(torch.transpose(level_values, 1, 0), k=num_levels) | |
level_values = torch.flip(level_values, dims=(-1, )) | |
level_indices = torch.flip(level_indices, dims=(-1, )) | |
level_values = torch.transpose(level_values, 1, 0) | |
level_indices = torch.transpose(level_indices, 1, 0) | |
# calculate threshold | |
thrs = torch.mm(thrs_multiplier, level_values) | |
# calculate level codes per channel | |
level_codes_channelwise_dims = [num_levels * out_channels, nbit] # (128* 2**3, 3) | |
level_codes_channelwise = torch.zeros(level_codes_channelwise_dims) | |
for i in range(num_levels): | |
eq = torch.eq(level_indices,i) | |
# print (level_codes_channelwise.size()) | |
# print (level_indices.size()) | |
level_codes_channelwise = torch.where(eq.view(-1, 1), level_codes_channelwise + level_codes[i], level_codes_channelwise) | |
level_codes_channelwise = level_codes_channelwise.view(num_levels, out_channels, nbit) | |
y = torch.zeros_like(x) + level_values[0].view(-1,1,1) | |
reshape_x = x.view(-1, out_channels) | |
zero_dims = [reshape_x.size(0) * out_channels, nbit] | |
bits_y = torch.ones(zero_dims).fill_(-1.) | |
zero_y = torch.zeros_like(x) | |
zero_bits_y = torch.zeros(zero_dims) | |
zero_bits_y = zero_bits_y.view(-1, out_channels, nbit) | |
for i in range(num_levels - 1): | |
g = torch.gt(x, thrs[i].view(-1, 1,1)) | |
y = torch.where(g, zero_y + level_values[i + 1].view(-1, 1,1), y)# | |
bits_y = torch.where(g.view(-1,1), (zero_bits_y + level_codes_channelwise[i + 1]).view(-1, nbit) , bits_y) | |
bits_y = bits_y.view(-1, out_channels, nbit) | |
return y, bits_y | |
def w_basis_regression(x, bits_y, nbit, basis=None):
    """Channelwise ridge-regularized least-squares refit of the weight basis.

    Per channel, solves (B^T B + delta*I)^-1 (B^T x + delta*basis), where B
    is that channel's code matrix from `w_encode` and delta both keeps the
    system invertible and pulls the solution toward the current basis.

    Args:
        x: original weight tensor, channel axis at dim 1.
        bits_y: (-1, out_channels, nbit) signed codes from `w_encode`.
        nbit: number of bits.
        basis: (nbit, out_channels) current basis (regularization target).
            BUG FIX: the original read `basis` as an undeclared global; the
            parameter (defaulting to that global) keeps callers working while
            making the dependency explicit.  Pass it explicitly in new code.

    Returns:
        (nbit, out_channels) tensor of refitted basis values.
    """
    if basis is None:
        basis = globals()["basis"]  # backward compatibility with the script
    delta = 0.0001
    out_channels = x.size(1)
    reshape_x = x.view(-1, out_channels)
    # (1, N) row of ones: summing over samples via a matmul.
    sum_multiplier = torch.ones(1, reshape_x.size(0))
    sum_multiplier_basis = torch.ones(1, nbit)
    BT = bits_y.permute(2, 0, 1)  # -> (nbit, num_samples, out_channels)
    # B^T B per channel, with delta added on the diagonal.
    BTxB = []
    for i in range(nbit):
        for j in range(nbit):
            BTxBij = BT[i] * BT[j]
            BTxBij = torch.mm(sum_multiplier, BTxBij)
            if i == j:
                mat_one = torch.ones(1, out_channels)
                BTxBij = BTxBij + (delta * mat_one)  # + delta*I entry
            BTxB.append(BTxBij)
    BTxB = torch.stack(BTxB).view(nbit, nbit, out_channels)
    # Invert each channel's nbit x nbit system; small sizes get the
    # closed-form (2x2) or scalar (1x1) inverse.
    if nbit > 2:
        BTxB_inv = torch.inverse(BTxB.permute(2, 0, 1)).permute(1, 2, 0)
    elif nbit == 2:
        det = BTxB[0][0] * BTxB[1][1] - BTxB[0][1] * BTxB[1][0]
        inv = [BTxB[1][1] / det, -BTxB[0][1] / det,
               -BTxB[1][0] / det, BTxB[0][0] / det]
        BTxB_inv = torch.stack(inv).view(nbit, nbit, out_channels)
    else:  # nbit == 1
        BTxB_inv = torch.reciprocal(BTxB)
    # B^T x per channel, pulled toward the current basis by delta.
    BTxX = []
    for i in range(nbit):
        BTxXi0 = BT[i] * reshape_x
        BTxXi0 = torch.mm(sum_multiplier, BTxXi0)
        BTxX.append(BTxXi0)
    BTxX = torch.stack(BTxX).view(nbit, out_channels)
    BTxX = BTxX + (delta * basis)
    # new_basis = BTxB_inv @ BTxX, computed channelwise.
    new_basis = []
    for i in range(nbit):
        new_basis_i = BTxB_inv[i] * BTxX
        new_basis_i = torch.mm(sum_multiplier_basis, new_basis_i)
        new_basis.append(new_basis_i)
    new_basis = torch.stack(new_basis).view(nbit, out_channels)
    return new_basis
print ("*************** Approximation activations **************** ") | |
nbit = 10 | |
NORM_PPF_0_75 = 0.6745 | |
MOVING_AVERAGES_FACTOR = 0.4 | |
loss = torch.nn.MSELoss() | |
basis = torch.tensor([(NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** i) for i in range(nbit)]) | |
x = torch.randn(3,3,3) * 0.5 + 1 | |
for _ in range(10): | |
quantized_x, bits_x = a_encode(x, nbit, basis) | |
print ("Loss: ", loss(quantized_x, x)) | |
new_basis = a_basis_regression(x, bits_x, nbit).squeeze(-1) | |
basis = basis * MOVING_AVERAGES_FACTOR + new_basis * (1 - MOVING_AVERAGES_FACTOR) | |
print (x[0][-1]) | |
print ("~~~~~~~~~~~~~~") | |
print (quantized_x[0][-1]) | |
print ("*************** Approximation weights **************** ") | |
w = torch.randn(3,18,7,7) * 0.1 | |
n = 7 * 7 * 18 | |
base = NORM_PPF_0_75 * ((2. / n) ** 0.5) / (2 ** (nbit - 1)) | |
init_basis = [] | |
for j in range(nbit): | |
init_basis.append([(2 ** j) * base for i in range(w.size(1))]) | |
basis = torch.tensor(init_basis) | |
for _ in range(10): | |
quantized_w, bits_w = w_encode(w, nbit, basis) | |
print ("Loss: ", loss(quantized_w, w)) | |
new_basis = w_basis_regression(w, bits_w, nbit) | |
basis = basis * MOVING_AVERAGES_FACTOR + new_basis * (1 - MOVING_AVERAGES_FACTOR) | |
print (w[0][-1]) | |
print ("~~~~~~~~~~~~~~") | |
print (quantized_w[0][-1]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment