#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch


def round_ste(x):
    # Round to nearest with a straight-through estimator (STE):
    # the forward pass uses the rounded value, while the backward pass
    # treats rounding as the identity so gradients flow through.
    return torch.floor(x + 0.5).detach() + (x - x.detach())


def round_clip(x, prec):
    # Round with the STE, then clip to the unsigned prec-bit range [0, 2**prec - 1].
    x = round_ste(x)
    x = torch.clamp(x, 0, 2.**prec - 1)
    return x


def scale_and_min(x, prec, axis, clipping=1.0, stop_gradient_scale=False):
    # Asymmetric quantization parameters computed per slice along `axis`:
    # the (optionally clipped) [min, max] range is mapped onto 2**prec levels.
    min_val = torch.amin(x, dim=axis, keepdim=True)
    max_val = torch.amax(x, dim=axis, keepdim=True)
    min_val = min_val * clipping
    max_val = max_val * clipping
    scale = (max_val - min_val) / (2.**prec - 1)
    if stop_gradient_scale:
        # Optionally block gradients from flowing into the scale.
        scale = scale.detach()
    return scale, min_val


def quantize(x, prec, axis, clipping=1.0, stop_gradient_scale=False):
    # Asymmetric fake quantization: shift by the minimum, divide by the scale,
    # then round/clip to prec-bit integers. A zero scale (constant slice) maps to 0.
    scale, min_val = scale_and_min(
        x, prec, axis, clipping, stop_gradient_scale)
    x = x - min_val
    x = torch.where(scale != 0, torch.divide(x, scale), torch.zeros_like(x))
    qx = round_clip(x, prec)
    return qx, scale, min_val


def dequantize(qx, scale, min_val):
    # Map quantized integers back to the original floating-point range.
    deqx = qx * scale
    deqx = deqx + min_val
    return deqx


class I2WasymScLinear(torch.nn.Linear):
    """Linear layer with 2-bit asymmetric, per-output-row weight fake quantization."""

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super(I2WasymScLinear, self).__init__(in_features, out_features, bias)

    def _quant_weight(self):
        # Quantize-dequantize the weight on the fly (quantization-aware training);
        # axis=1 computes min/max over in_features, i.e. per output row.
        qx, scale, min_val = quantize(
            self.weight, prec=2, axis=1, clipping=1.0, stop_gradient_scale=False)
        return dequantize(qx, scale, min_val)

    def forward(self, input):
        return torch.nn.functional.linear(input, self._quant_weight(), bias=self.bias)
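

# A minimal usage sketch (not part of the original gist): swap a standard
# nn.Linear for I2WasymScLinear and train as usual; the STE lets gradients
# flow through the 2-bit weight rounding. Shapes and seeds below are
# illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = I2WasymScLinear(in_features=16, out_features=8, bias=True)
    x = torch.randn(4, 16)
    y = layer(x)
    # Backward pass works because round_ste passes gradients straight through.
    y.sum().backward()
    print("output shape:", y.shape)
    print("weight grad exists:", layer.weight.grad is not None)
    # Each row of the fake-quantized weight takes at most 2**2 = 4 distinct values.
    qw = layer._quant_weight()
    print("distinct levels in first row:", torch.unique(qw[0]).numel())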