Skip to content

Instantly share code, notes, and snippets.

@santisy
Last active March 31, 2018 21:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save santisy/5c3b8e15f13c1c1719fabfd105c970df to your computer and use it in GitHub Desktop.
Save santisy/5c3b8e15f13c1c1719fabfd105c970df to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
def _l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm.

    The norm is taken over all elements of ``v`` (``torch.norm`` with no
    ``dim``). ``eps`` is added to the denominator so an all-zero input
    yields zeros instead of NaNs.
    """
    return v / (torch.norm(v, p=2) + eps)


def max_singular_value(W, u=None, Ip=1):
    """Estimate the largest singular value of ``W`` by power iteration.

    ``W`` is flattened to 2-D as ``(W.size(0), -1)`` (n x m), so
    higher-rank weights are treated as a matrix over their first dimension.

    Args:
        W: Weight tensor. Gradients flow from the returned ``sigma`` back
            into it.
        u: Optional ``1 x n`` estimate of the leading left singular vector
            carried over from a previous call (warm start). When ``None``,
            a random start vector is drawn on ``W``'s device with ``W``'s
            dtype. A supplied ``u`` is detached and moved to ``W``'s
            device/dtype.
        Ip: Number of power-iteration steps. Must be at least 1.

    Returns:
        ``(sigma, u)``:
        - ``sigma`` is a ``1 x 1`` tensor holding the estimate ``u W v^T``.
          It is differentiable w.r.t. ``W``.
        - ``u`` is the updated ``1 x n`` vector to pass back on the next
          call. It carries no autograd history, so callers may cache it
          across forward/backward passes.

    Raises:
        ValueError: If ``Ip < 1``. With no iterations, ``v`` would never be
            computed.
    """
    if Ip < 1:
        raise ValueError("Ip must be >= 1, got {}".format(Ip))

    W = W.reshape(W.size(0), -1)  # n x m

    # Power iteration is only an estimation step, so run it on a detached
    # copy of W. If u/v kept autograd history, a caller that caches u would
    # link the next forward pass to this pass's graph. The second backward()
    # then fails with "Trying to backward through the graph a second time".
    W_ng = W.detach()
    if u is None:
        _u = torch.randn(1, W_ng.size(0), device=W_ng.device, dtype=W_ng.dtype)
    else:
        # Follow W if the module was moved after u was created.
        _u = u.detach().to(device=W_ng.device, dtype=W_ng.dtype)

    for _ in range(Ip):
        _v = _l2normalize(torch.mm(_u, W_ng))  # 1 x m
        _u = _l2normalize(torch.mm(W_ng, _v.t())).view(1, -1)  # 1 x n

    # u and v are constants here, so d(sigma)/dW = u^T v. This is the same
    # treatment as torch.nn.utils.spectral_norm.
    sigma = _u.mm(W).mm(_v.t())  # 1 x 1
    return sigma, _u
class SNLinear(nn.Linear):
    """``nn.Linear`` whose weight is divided by its estimated spectral norm.

    On every access to ``W_bar`` (i.e. every forward pass), the largest
    singular value of ``weight`` is estimated with ``max_singular_value``.
    The power iteration is warm-started from the vector kept in ``self.u``.

    Args:
        in_features, out_features, bias: As for ``nn.Linear``.
        Ip: Number of power-iteration steps per forward pass.
    """

    def __init__(self, in_features, out_features, bias=True, Ip=1):
        super(SNLinear, self).__init__(in_features, out_features, bias)
        self.Ip = Ip
        # Running estimate of the leading left singular vector of the
        # (out_features x in_features) weight, shape 1 x out_features.
        # None until the first forward pass. This is a plain attribute, not a
        # registered buffer, so it is not included in state_dict().
        self.u = None

    @property
    def W_bar(self):
        """Return the spectrally normalized weight ``weight / sigma``.

        Side effect: replaces ``self.u`` with the updated power-iteration
        vector.
        """
        sigma, _u = max_singular_value(self.weight, u=self.u, Ip=self.Ip)
        # Cache u without autograd history. Keeping a graph-connected tensor
        # would tie the next forward pass to this pass's graph, so the second
        # backward() fails with "Trying to backward through the graph a
        # second time". It would also keep the old graph alive in memory.
        self.u = _u.detach()
        return self.weight / sigma

    def forward(self, input):
        return F.linear(input, self.W_bar, self.bias)
if __name__ == '__main__':
    # Smoke test: two forward/backward rounds through the same SNLinear.
    # The second backward() must succeed. It fails with "Trying to backward
    # through the graph a second time" if the cached power-iteration vector
    # ``u`` still carries autograd history from the first pass.
    # Pick the device at runtime so the demo also runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_var = torch.randn(3, 10, device=device)
    linear1 = SNLinear(10, 20).to(device)
    output1 = linear1(input_var)
    output1.sum().backward()
    linear1.zero_grad()
    output2 = linear1(input_var)
    output2.sum().backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment