Last active
March 31, 2018 21:44
-
-
Save santisy/5c3b8e15f13c1c1719fabfd105c970df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from torch.nn import Parameter | |
def _l2normalize(v, eps=1e-12): | |
return v / (torch.norm(v, p=2) + eps) | |
def max_singular_value(W, u=None, Ip=1): | |
""" | |
Apply power iteration for weight parameter | |
""" | |
W = W.view(W.size(0), -1) | |
size = W.size() # n x m | |
if u is None: | |
u = Parameter(torch.FloatTensor(1, size[0]).normal_(), | |
requires_grad=False) # 1 x n | |
u = u.cuda() | |
_u = u | |
for _ in range(Ip): | |
_v = _l2normalize(torch.mm(_u, W)) # 1 x m | |
_u = _l2normalize(torch.mm(W, _v.t())) # n x 1 | |
_u = _u.view(1, -1) | |
sigma = _u.mm(W).mm(_v.t()) | |
return sigma, _u | |
class SNLinear(nn.Linear): | |
def __init__(self, in_features, out_features, bias=True, Ip=1): | |
super(SNLinear, self).__init__(in_features, out_features, bias) | |
self.Ip = Ip | |
self.u = None | |
@property | |
def W_bar(self): | |
sigma, _u = max_singular_value(self.weight, u=self.u, Ip=self.Ip) | |
self.u = _u | |
return self.weight / sigma | |
def forward(self, input): | |
return F.linear(input, self.W_bar, self.bias) | |
if __name__ == '__main__': | |
input_var = Variable(torch.randn(3, 10).cuda()) | |
linear1 = SNLinear(10, 20) | |
linear1 = linear1.cuda() | |
output1 = linear1(input_var) | |
output1.sum().backward() | |
linear1.zero_grad() | |
output2 = linear1(input_var) | |
# REPORT BUG HERE: | |
# Trying to backward through the graph a second time, | |
# but the buffers have already been freed. | |
# Specify retain_graph=True when calling backward the first time. | |
######################################### | |
output2.sum().backward() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment