from torch.autograd import Variable
import torch
from torch import nn
from collections import OrderedDict
from IPython import embed
from torch.autograd.function import InplaceFunction, Function
import torch.nn.functional as F
import math


def _mean(p, dim):
    """Computes the mean over all dimensions except dim."""
    if dim is None:
        return p.mean()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
    else:
        return _mean(p.transpose(0, dim), 0).transpose(0, dim)
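

# Quick sanity check (not part of the original gist): _mean reduces over every
# dimension except `dim`, so a (N, C, H, W) tensor averaged with dim=1 keeps
# only the channel axis.
def _example_mean():
    x = torch.randn(4, 3, 8, 8)
    per_channel = _mean(x, dim=1)
    assert per_channel.shape == (1, 3, 1, 1)
    return per_channel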


class UniformQuantize(InplaceFunction):

    @classmethod
    def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=False, inplace=False, enforce_true_zero=False,
                num_chunks=None, out_half=False):
        num_chunks = input.shape[0] if num_chunks is None else num_chunks
        if min_value is None or max_value is None:
            B = input.shape[0]
            y = input.view(B // num_chunks, -1)
            if min_value is None:
                min_value = y.min(-1)[0].mean(-1)  # C
                # min_value = float(input.view(input.size(0), -1).min(-1)[0].mean())
            if max_value is None:
                # max_value = float(input.view(input.size(0), -1).max(-1)[0].mean())
                max_value = y.max(-1)[0].mean(-1)  # C
        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        qmin = 0.
        qmax = 2. ** num_bits - 1.
        scale = (max_value - min_value) / (qmax - qmin)
        scale = max(scale, 1e-8)

        if enforce_true_zero:
            initial_zero_point = qmin - min_value / scale
            zero_point = 0.
            # make zero exactly representable
            if initial_zero_point < qmin:
                zero_point = qmin
            elif initial_zero_point > qmax:
                zero_point = qmax
            else:
                zero_point = initial_zero_point
            zero_point = int(zero_point)
            output.div_(scale).add_(zero_point)
        else:
            output.add_(-min_value).div_(scale).add_(qmin)

        if ctx.stochastic:
            noise = output.new(output.shape).uniform_(-0.5, 0.5)
            output.add_(noise)
        output.clamp_(qmin, qmax).round_()  # quantize

        if enforce_true_zero:
            output.add_(-zero_point).mul_(scale)  # dequantize
        else:
            output.add_(-qmin).mul_(scale).add_(min_value)  # dequantize
        if out_half and num_bits <= 16:
            output = output.half()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the gradient through unchanged
        grad_input = grad_output
        # one gradient slot per forward argument: input, num_bits, min_value,
        # max_value, stochastic, inplace, enforce_true_zero, num_chunks, out_half
        return grad_input, None, None, None, None, None, None, None, None


class UniformQuantizeGrad(InplaceFunction):

    @classmethod
    def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=True, inplace=False):
        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic
        return input

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.min_value is None:
            min_value = float(grad_output.min())
            # min_value = float(grad_output.view(
            #     grad_output.size(0), -1).min(-1)[0].mean())
        else:
            min_value = ctx.min_value
        if ctx.max_value is None:
            max_value = float(grad_output.max())
            # max_value = float(grad_output.view(
            #     grad_output.size(0), -1).max(-1)[0].mean())
        else:
            max_value = ctx.max_value
        grad_input = UniformQuantize().apply(grad_output, ctx.num_bits,
                                             min_value, max_value,
                                             ctx.stochastic, ctx.inplace)
        return grad_input, None, None, None, None, None


def quantize(x, num_bits=8, min_value=None, max_value=None, num_chunks=None,
             stochastic=False, inplace=False):
    # UniformQuantize.forward expects (input, num_bits, min_value, max_value,
    # stochastic, inplace, enforce_true_zero, num_chunks, out_half), so pass
    # the arguments positionally in that order.
    return UniformQuantize().apply(x, num_bits, min_value, max_value,
                                   stochastic, inplace, False, num_chunks)


def quantize_grad(x, num_bits=8, min_value=None, max_value=None,
                  stochastic=True, inplace=False):
    return UniformQuantizeGrad().apply(x, num_bits, min_value, max_value,
                                       stochastic, inplace)
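

# Illustrative usage (not part of the original gist): `quantize` fake-quantizes
# the tensor in the forward pass, while `quantize_grad` is an identity forward
# and only quantizes the gradient flowing back through it.
def _example_quantize():
    x = torch.randn(8, 16, requires_grad=True)
    xq = quantize(x, num_bits=4)             # 4-bit fake quantization of x
    y = quantize_grad(xq, num_bits=8).sum()  # gradient of y gets 8-bit quantized
    y.backward()                             # straight-through estimator fills x.grad
    return xq, x.grad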


class QuantMeasure(nn.Module):
    """Tracks a running min/max of the input and quantizes with it."""

    def __init__(self, num_bits=8, momentum=0.1,
                 init_running_min=-2.0, init_running_max=2.0):
        super(QuantMeasure, self).__init__()
        self.register_buffer('running_min', torch.zeros(1))
        self.register_buffer('running_max', torch.zeros(1))
        self.momentum = momentum
        self.num_bits = num_bits
        # Fill the buffers in place; assigning a plain float to a registered
        # buffer would raise a TypeError.
        self.running_min.fill_(init_running_min)
        self.running_max.fill_(init_running_max)

    def forward(self, input):
        if self.training:
            min_value = input.detach().view(
                input.size(0), -1).min(-1)[0].mean()
            max_value = input.detach().view(
                input.size(0), -1).max(-1)[0].mean()
            self.running_min.mul_(self.momentum).add_(
                min_value * (1 - self.momentum))
            self.running_max.mul_(self.momentum).add_(
                max_value * (1 - self.momentum))
        else:
            min_value = self.running_min
            max_value = self.running_max
        return quantize(input, self.num_bits, min_value=float(min_value),
                        max_value=float(max_value), num_chunks=16)
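

# Illustrative usage (not part of the original gist): in training mode
# QuantMeasure updates its running min/max from the batch; in eval mode it
# quantizes with the stored statistics instead.
def _example_quant_measure():
    qm = QuantMeasure(num_bits=8)
    qm.train()
    _ = qm(torch.randn(16, 32))    # updates running_min / running_max
    qm.eval()
    out = qm(torch.randn(16, 32))  # uses the running statistics
    return out, float(qm.running_min), float(qm.running_max)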


class QConv2d(nn.Conv2d):
    """Conv2d with quantized input, weight, bias and (optionally) gradient."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 num_bits=8, num_bits_weight=None, num_bits_grad=None,
                 biprecision=False):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        self.num_bits = num_bits
        self.num_bits_weight = num_bits_weight or num_bits
        self.num_bits_grad = num_bits_grad
        self.quantize_input = QuantMeasure(self.num_bits)
        self.biprecision = biprecision

    def forward(self, input):
        qinput = self.quantize_input(input)
        qweight = quantize(self.weight, num_bits=self.num_bits_weight,
                           min_value=float(self.weight.min()),
                           max_value=float(self.weight.max()))
        if self.bias is not None:
            qbias = quantize(self.bias, num_bits=self.num_bits_weight)
        else:
            qbias = None
        if not self.biprecision or self.num_bits_grad is None:
            output = F.conv2d(qinput, qweight, qbias, self.stride,
                              self.padding, self.dilation, self.groups)
            if self.num_bits_grad is not None:
                output = quantize_grad(output, num_bits=self.num_bits_grad)
        else:
            # conv2d_biprec is not defined in this gist; it is expected to come
            # from the quantized-training code this snippet is adapted from.
            output = conv2d_biprec(qinput, qweight, qbias, self.stride,
                                   self.padding, self.dilation, self.groups,
                                   num_bits_grad=self.num_bits_grad)
        return output
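

# Illustrative usage (not part of the original gist): QConv2d is a drop-in
# replacement for nn.Conv2d with (by default) 8-bit activations and weights.
def _example_qconv2d():
    layer = QConv2d(3, 16, kernel_size=3, padding=1, num_bits=8)
    x = torch.randn(2, 3, 32, 32)
    y = layer(x)
    assert y.shape == (2, 16, 32, 32)
    return y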


def compute_integral_part(input, overflow_rate):
    """Number of integer bits needed to cover all but `overflow_rate` of the values."""
    abs_value = input.abs().view(-1)
    sorted_value = abs_value.sort(dim=0, descending=True)[0]
    split_idx = int(overflow_rate * len(sorted_value))
    v = sorted_value[split_idx]
    if isinstance(v, Variable):  # in modern PyTorch this is just a 0-dim tensor
        v = float(v)
    sf = math.ceil(math.log2(v + 1e-12))
    return sf


def linear_quantize(input, sf, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1
    delta = math.pow(2.0, -sf)
    bound = math.pow(2.0, bits - 1)
    min_val = -bound
    max_val = bound - 1
    rounded = torch.floor(input / delta + 0.5)
    clipped_value = torch.clamp(rounded, min_val, max_val) * delta
    return clipped_value
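

# Illustrative usage (not part of the original gist). The pairing below assumes
# the pytorch-playground convention: compute_integral_part gives the integer
# bits, and the remaining bits - 1 - sf_int bits are fractional.
def _example_linear_quantize(bits=8, overflow_rate=0.0):
    w = torch.randn(64)
    sf_int = compute_integral_part(w, overflow_rate)  # integer bits needed
    sf = bits - 1 - sf_int                            # fractional bits
    return linear_quantize(w, sf, bits)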


def log_minmax_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input), 0.0, 0.0
    s = torch.sign(input)
    input0 = torch.log(torch.abs(input) + 1e-20)
    v = min_max_quantize(input0, bits)
    v = torch.exp(v) * s
    return v


def log_linear_quantize(input, sf, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input), 0.0, 0.0
    s = torch.sign(input)
    input0 = torch.log(torch.abs(input) + 1e-20)
    v = linear_quantize(input0, sf, bits)
    v = torch.exp(v) * s
    return v


def min_max_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1
    min_val, max_val = input.min(), input.max()
    if isinstance(min_val, Variable):  # in modern PyTorch these are 0-dim tensors
        max_val = float(max_val)
        min_val = float(min_val)
    input_rescale = (input - min_val) / (max_val - min_val)
    n = math.pow(2.0, bits) - 1
    v = torch.floor(input_rescale * n + 0.5) / n
    v = v * (max_val - min_val) + min_val
    return v


def tanh_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input)
    input = torch.tanh(input)  # map to [-1, 1]
    input_rescale = (input + 1.0) / 2  # map to [0, 1]
    n = math.pow(2.0, bits) - 1
    v = torch.floor(input_rescale * n + 0.5) / n
    v = 2 * v - 1  # back to [-1, 1]
    v = 0.5 * torch.log((1 + v) / (1 - v))  # arctanh
    return v
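

# Illustrative comparison (not part of the original gist): apply the min-max,
# log-min-max and tanh quantizers to the same bounded tensor and print the
# mean reconstruction error of each scheme.
def _example_quantizer_comparison(bits=4):
    x = torch.empty(1000).uniform_(-1, 1)
    for name, fn in [('minmax', min_max_quantize),
                     ('log_minmax', log_minmax_quantize),
                     ('tanh', tanh_quantize)]:
        xq = fn(x, bits)
        print(name, float((x - xq).abs().mean()))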


import torch.nn as nn


class My_Layer(nn.Module):
    """Wraps an existing module and quantizes its input before calling it."""

    def __init__(self, m, num_bits):
        super(My_Layer, self).__init__()
        self.quantize_input = QuantMeasure(num_bits)
        self.m = m

    def forward(self, x):
        x = self.quantize_input(x)
        return self.m(x)


def replace(model):
    """Recursively replace every target layer of `model` (in place) with My_Layer."""
    if len(model._modules) == 0:
        return model
    for i, m in model._modules.items():
        if if_target_layer(m):
            model._modules[i] = My_Layer(m, 8)
        else:
            replace(m)  # recurse; children are modified in place
    return model
# usage: replace(model)


def if_target_layer(m):
    '''
    define your own identification function
    '''
    return isinstance(m, nn.Linear)
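

# Illustrative usage (not part of the original gist): wrap every nn.Linear in a
# small model with My_Layer so its inputs are quantized to 8 bits.
def _example_replace():
    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
    replace(model)
    assert isinstance(model[0], My_Layer) and isinstance(model[2], My_Layer)
    return model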


def duplicate_model_with_quant(model, **kwargs):
    """Assumes the original model contains at least one nn.Sequential.
    You can use this function to implement different quantization schemes.
    """
    assert kwargs['type'] in ['linear', 'minmax', 'log', 'tanh']
    if isinstance(model, nn.Sequential):
        # if it is a Sequential, build a new Sequential on top of it
        l = OrderedDict()  # the new sequence
        for k, v in model._modules.items():  # inside the sequence
            if not isinstance(v, nn.Sequential):
                l[k] = v  # add the original layer into the new sequence
                if isinstance(v, nn.ReLU):
                    # placeholder quant layer; swap in the scheme selected by
                    # kwargs['type'] (linear / minmax / log / tanh)
                    quant_layer = nn.Tanh()
                    # add the new layer into the new sequence
                    l['{}_{}_quant'.format(k, kwargs['type'])] = quant_layer
            else:
                l[k] = duplicate_model_with_quant(v, **kwargs)
        m = nn.Sequential(l)
        return m
    else:
        # if it is not a Sequential, go deeper to search
        for k, v in model._modules.items():
            model._modules[k] = duplicate_model_with_quant(v, **kwargs)
        return model
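

# Illustrative usage (not part of the original gist): insert a quant layer
# after every ReLU inside a Sequential model.
def _example_duplicate_model_with_quant():
    model = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(10, 20)),
        ('relu1', nn.ReLU()),
        ('fc2', nn.Linear(20, 2)),
    ]))
    qmodel = duplicate_model_with_quant(model, type='tanh')
    print(qmodel)  # contains an extra 'relu1_tanh_quant' layer
    return qmodel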


class BinOp(object):
    def __init__(self, model):
        '''
        Attributes:
            self.saved_params: list of saved full-precision parameters
            self.target_modules: list of parameters to be quantized
            self.num_of_params: length of self.target_modules
            self.quant_layer_list: list of layers that need quantization
        '''
        self.model = model
        self.saved_params = []
        self.target_modules = []
        self.get_quant_layer()
        self.num_of_params = len(self.quant_layer_list)
        for m in self.quant_layer_list:
            tmp = m.weight.data.clone()
            self.saved_params.append(tmp)
            self.target_modules.append(m.weight)

    def get_quant_layer(self):
        '''
        Builds self.quant_layer_list: the layers that need to be quantized.
        '''
        self.quant_layer_list = []
        for index, m in enumerate(self.model.modules()):  # scan all layers
            if isinstance(m, nn.Conv2d):  # choose which layers you want to quantize
                self.quant_layer_list.append(m)
                print('Quantizing the ', index, ' layer: ', m)
            else:
                print('Ignoring the ', index, ' layer: ', m)
        print('Total quantized layers: ', len(self.quant_layer_list))
        return self.quant_layer_list

    def binarization(self):
        self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams()

    def meancenterConvParams(self):
        for index in range(self.num_of_params):
            negMean = self.target_modules[index].data.mean(1, keepdim=True).\
                mul(-1).expand_as(self.target_modules[index].data)
            self.target_modules[index].data = self.target_modules[index].data.add(negMean)

    def clampConvParams(self):
        for index in range(self.num_of_params):
            # clamp the weights to [-1, 1] in place
            self.target_modules[index].data.clamp_(-1.0, 1.0)

    def save_params(self):
        for index in range(self.num_of_params):
            self.saved_params[index].copy_(self.target_modules[index].data)

    def binarizeConvParams(self):
        for index in range(self.num_of_params):
            n = self.target_modules[index].data[0].nelement()
            s = self.target_modules[index].data.size()
            # per-output-channel scaling factor: mean absolute weight value
            m = self.target_modules[index].data.norm(1, 3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n)
            # XNOR-Net-style binarization: W <- sign(W) * m
            self.target_modules[index].data = \
                self.target_modules[index].data.sign().mul(m.expand(s))

    def restore(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])

    def updateBinaryGradWeight(self):
        for index in range(self.num_of_params):
            weight = self.target_modules[index].data
            n = weight[0].nelement()
            s = weight.size()
            m = weight.norm(1, 3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
            m[weight.lt(-1.0)] = 0
            m[weight.gt(1.0)] = 0
            # m = m.add(1.0/n).mul(1.0-1.0/s[1]).mul(n)
            # self.target_modules[index].grad.data = \
            #     self.target_modules[index].grad.data.mul(m)
            m = m.mul(self.target_modules[index].grad.data)
            m_add = weight.sign().mul(self.target_modules[index].grad.data)
            m_add = m_add.sum(3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
            m_add = m_add.mul(weight.sign())
            self.target_modules[index].grad.data = \
                m.add(m_add).mul(1.0 - 1.0 / s[1]).mul(n)
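

# Illustrative training-loop usage (not part of the original gist), following
# the usual XNOR-Net recipe: binarize before the forward pass, restore the
# full-precision weights before the optimizer step, then rescale the gradients.
# `bin_op = BinOp(model)` is assumed to be built once before training starts.
def _example_binop_step(bin_op, model, optimizer, criterion, data, target):
    bin_op.binarization()            # replace conv weights with sign(W) * alpha
    output = model(data)
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    bin_op.restore()                 # bring back the saved full-precision weights
    bin_op.updateBinaryGradWeight()  # adjust gradients for the binarized weights
    optimizer.step()
    return loss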