@AranKomat
Created March 20, 2020

import math

import torch
import torch.nn as nn
from torch.nn import init


class PN_(torch.autograd.Function):
    # Normalizes x by the running quadratic mean psi; the backward pass replaces
    # the usual batch term with the running statistic nu.
    @staticmethod
    def forward(ctx, x, states):  # x: [b*l, d] (flattened in PN.forward)
        eps, psi, nu = states
        x_hat = x / (psi + eps)
        ctx.save_for_backward(x_hat, eps, psi, nu)
        return x_hat

    @staticmethod
    def backward(ctx, grad):
        x_hat, eps, psi, nu = ctx.saved_tensors
        grad_x = (grad - nu * x_hat) / (psi + eps)
        return grad_x, None  # no gradient for the `states` tuple
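

# --- Sanity-check sketch (not in the original gist): verifies on random data
# that PN_.apply divides by (psi + eps) and that its backward matches the
# hand-derived formula grad_x = (grad - nu * x_hat) / (psi + eps). The shapes
# and the helper name _check_pn_function are illustrative assumptions.
def _check_pn_function(b=4, d=8):
    x = torch.randn(b, d, requires_grad=True)
    eps = torch.full((1,), 1e-5)
    psi = torch.rand(d) + 0.5
    nu = torch.randn(d)
    x_hat = PN_.apply(x, (eps, psi, nu))
    assert torch.allclose(x_hat, x / (psi + eps))
    grad_out = torch.randn_like(x_hat)
    x_hat.backward(grad_out)
    expected = (grad_out - nu * (x.detach() / (psi + eps))) / (psi + eps)
    assert torch.allclose(x.grad, expected, atol=1e-6)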


class PN(nn.Module):
    def __init__(self, features, fp32=True, eps=1.0e-5):
        super(PN, self).__init__()
        self.fp32 = fp32  # True if we compute in fp32 for PN. Assume that we use fp16 elsewhere.
        # TODO: take care of params being fp16
        self.features = features
        std = math.sqrt(1 / features)  # note: currently unused
        self.gamma = nn.Parameter(torch.Tensor(features))
        self.beta = nn.Parameter(torch.Tensor(features))
        self.alpha_f = 0.95  # forward moving-average rate; tune
        self.alpha_b = 0.95  # backward moving-average rate; tune
        # Running statistics: psi for the forward pass, nu for the backward pass.
        self.register_buffer('psi', torch.ones(features))
        self.register_buffer('nu', torch.zeros(features))
        self.register_buffer('eps', torch.zeros(1).fill_(eps))
        self.Gamma = None
        self.reset_parameters()

    def reset_parameters(self):
        init.uniform_(self.gamma)
        init.zeros_(self.beta)

    def forward(self, x):
        x_dtype = x.dtype
        if self.fp32:
            x = x.float()
        x_size = list(x.size())
        x = x.view(-1, x_size[-1])  # flatten to [b*l, d]
        if self.training:
            if self.Gamma is not None:
                # Update the backward running statistic nu from the gradient
                # captured by the hook during the previous backward pass.
                with torch.no_grad():
                    Lambda = (self.x_hat_grad * self.x_hat).mean(0)
                    self.nu = self.nu * (1 - (1 - self.alpha_b) * self.Gamma) + (1 - self.alpha_b) * Lambda
                del self.x_hat, self.x_hat_grad

            def extract(grad):
                # Hook: stash the gradient w.r.t. x_hat for the next forward pass.
                self.x_hat_grad = grad.detach().clone()

            psi_b = (x ** 2).mean(0)  # batch quadratic mean
            x_hat = PN_.apply(x, (self.eps, self.psi, self.nu))
            self.x_hat = x_hat.detach().clone()
            if not x_hat.requires_grad:
                x_hat.requires_grad_()  # so the hook below can be registered
            x_hat.register_hook(extract)
            self.Gamma = (x_hat ** 2).mean(0).detach()
            y = self.gamma * x_hat + self.beta
            # Exponential moving average of the forward statistic psi.
            self.psi = torch.sqrt(self.alpha_f * (self.psi ** 2) + (1 - self.alpha_f) * (psi_b ** 2)).detach()
        else:
            y = self.gamma * x / (self.psi + self.eps) + self.beta
        return y.reshape(x_size).type(x_dtype)

    @staticmethod
    def fp32(model):  # Use this if you're using mixed precision; call it on the class, e.g. PN.fp32(model).
        # Note: instances shadow this name with the boolean attribute self.fp32 set in __init__.
        for m in model.modules():
            if isinstance(m, PN):
                m.float()
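

# --- Usage sketch (not in the original gist): PN is assumed to slot in where a
# LayerNorm would normally go. The tiny model, data, and training loop below are
# illustrative assumptions, not the author's setup; a few optimizer steps
# exercise the batch-statistics path (training) before switching to the
# running-statistics path (eval).
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 16), PN(16), nn.ReLU(), nn.Linear(16, 2))
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    data = torch.randn(32, 10, 16)  # [batch, length, features]
    for _ in range(3):
        opt.zero_grad()
        loss = model(data).pow(2).mean()
        loss.backward()
        opt.step()
    model.eval()  # inference uses the running psi instead of batch statistics
    with torch.no_grad():
        _ = model(data)
    # For mixed precision, convert the model to half precision and keep PN in fp32:
    #   model.half(); PN.fp32(model)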