Skip to content

Instantly share code, notes, and snippets.

@TeaPoly
Last active June 14, 2024 07:43
Show Gist options
  • Save TeaPoly/5b4a9a9dae891bbdc0530ab055c33da2 to your computer and use it in GitHub Desktop.
Save TeaPoly/5b4a9a9dae891bbdc0530ab055c33da2 to your computer and use it in GitHub Desktop.
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
def absmean_binarize(x, contract_dims, centralize=False, eps=1e-8):
    """Binarize ``x`` to +/-1 and compute an absolute-mean scale.

    The forward value of the binarized tensor is ``sign(x)`` with exact zeros
    mapped to ``+1``. Gradients pass straight through to ``x`` via ``ste``.

    Args:
        x: Tensor to binarize.
        contract_dims: Dimension(s) handed to ``torch.mean`` for the optional
            centering mean and for the scale (an int or a tuple of ints).
        centralize: If True, subtract the mean over ``contract_dims`` first.
        eps: Value substituted for exact zeros before the scale is computed.

    Returns:
        A ``(binarized, scale)`` tuple. ``binarized`` has the shape of ``x``,
        with forward values in ``{-1, +1}`` for finite inputs. ``scale`` is
        ``mean(|x|)`` over ``contract_dims`` with those dims kept as size 1,
        so ``binarized * scale`` broadcasts back to the shape of ``x``.
    """
    if centralize:
        x = x - torch.mean(x, dim=contract_dims, keepdim=True)
    # Exact zeros get a tiny magnitude so they still contribute to the scale.
    # torch.where casts eps to x's dtype, and 1e-8 underflows to 0 in float16,
    # so this step alone cannot guarantee a nonzero sign.
    x = torch.where(x == 0.0, torch.tensor(eps, device=x.device), x)
    scale = torch.mean(torch.abs(x), dim=contract_dims, keepdim=True)
    return ste(x, _sign_zero_as_positive), scale


def _sign_zero_as_positive(t):
    """Elementwise ``torch.sign`` with exact zeros mapped to +1."""
    s = torch.sign(t)
    # Only zero inputs are overridden; every other element keeps torch.sign's
    # result, so behavior for nonzero values is unchanged.
    return torch.where(t == 0, torch.ones_like(s), s)
def ste(x, fn):
    """Apply ``fn`` in the forward pass with a straight-through gradient.

    For finite ``x`` the forward value equals ``fn(x)``. The backward pass
    treats the whole op as the identity on ``x``, ignoring ``fn``'s own
    derivative.

    Args:
        x: Input tensor.
        fn: Callable applied to ``x``; its result is detached from autograd.

    Returns:
        A tensor carrying ``fn(x)``'s values whose gradient flows to ``x``.
    """
    # Numerically zero for finite x, but keeps x attached to the graph with
    # a derivative of exactly 1.
    passthrough = x - x.detach()
    surrogate = fn(x).detach()
    return passthrough + surrogate
class BinarizedLinear(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward pass uses a binarized weight.

    ``self.weight`` is stored in full precision. Each forward pass instead
    uses ``binarized * scale`` from ``absmean_binarize(weight, 1)``: a +/-1
    matrix scaled by the mean absolute value of each weight row (one row per
    output feature). The bias, if present, is used unquantized.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        """Build the layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: If True, the layer learns an additive bias.
            **kwargs: ``device`` and ``dtype`` are forwarded to
                ``torch.nn.Linear`` so parameters are created as requested.
                Any other keyword is accepted and ignored, so callers that
                pass extra options keep working.
        """
        # Forward only nn.Linear's factory kwargs; unknown keys would make
        # nn.Linear.__init__ raise TypeError.
        factory_kwargs = {
            key: kwargs[key] for key in ("device", "dtype") if key in kwargs
        }
        super().__init__(in_features, out_features, bias, **factory_kwargs)

    def _quant_weight(self):
        """Return the binarized, per-row-scaled weight used by ``forward``."""
        # contract_dims=1 averages over in_features, giving one scale per
        # output row of the (out_features, in_features) weight.
        binarized, scale = absmean_binarize(
            self.weight, 1, centralize=False, eps=1e-8
        )
        return binarized * scale

    def forward(self, input):
        """Compute ``input @ W_q.T + bias`` with the quantized weight ``W_q``."""
        return torch.nn.functional.linear(
            input, self._quant_weight(), bias=self.bias
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment