@AruniRC
Created March 19, 2019 18:29
def forward(self, x):
    # Backbone CNN features: (bs, ch, h, w)
    x = self.features(x)
    bs, ch, h, w = x.shape
    # Flatten the spatial grid and put channels last: (bs, h*w, ch)
    x = x.view(bs, ch, -1).transpose(2, 1)
    # x.register_hook(self.save_grad('x'))

    # Gram matrix, (bs, N, N), over the N = h*w local features in "x"
    K = x.bmm(x.transpose(2, 1))
    # K = x * x  # <--- IS THIS CORRECT for 1st order features????
    # Note: the element-wise product above has shape (bs, N, ch), so it cannot
    # replace K in the Sinkhorn update K.bmm(alpha) below; kept commented out.

    # Uniform initialization of the Sinkhorn weights
    # (Variable/.cuda() from the original are not needed in modern PyTorch)
    alpha = torch.ones(bs, h * w, 1, device=x.device)

    # Row sums of K raised to gamma, with near-zero rows masked out.
    # (Ci is computed but not used in the rest of this snippet.)
    Ci = torch.sum(K, 2, keepdim=True)
    mask = (Ci < 1e-10).detach()
    Ci = torch.pow(Ci, self.gamma)
    Ci[mask] = 0
    Ci = Ci.detach()

    # Sinkhorn iterations
    for _ in range(10):
        alpha = torch.pow(alpha + 1e-10, 1 - self.sinkhorn_t) / \
                (torch.pow(K.bmm(alpha) + 1e-10, self.sinkhorn_t) + 1e-10)

    # x = x * torch.pow(alpha, 0.5)
    # x = x.transpose(1, 2).bmm(x).view(bs, -1)  # EDIT THIS OUT FOR FIRST ORDER ????
    # Weight each local feature by its Sinkhorn coefficient (first-order path)
    x = x * alpha
    # Element-wise sqrt, L2 normalization, then the classifier
    x = torch.sqrt(x + 1e-8)
    x = torch.nn.functional.normalize(x)
    x = self.fc(x)
    return x
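
For context, here is a minimal, self-contained sketch of a module this forward method could plug into. The backbone (VGG-16 features), the classifier dimensions, and the default values of gamma and sinkhorn_t are assumptions for illustration only; the gist itself defines only the forward pass.

import torch
import torch.nn as nn
import torchvision

class SinkhornPoolingNet(nn.Module):
    """Hypothetical wrapper; only the forward() above comes from the gist."""

    def __init__(self, num_classes, gamma=0.5, sinkhorn_t=0.5):
        super().__init__()
        # Assumed backbone: any CNN trunk that returns a (bs, ch, h, w) map.
        self.features = torchvision.models.vgg16().features   # ch = 512
        self.gamma = gamma            # exponent applied to the row sums Ci (assumed default)
        self.sinkhorn_t = sinkhorn_t  # interpolation exponent in the Sinkhorn update (assumed default)
        # With the first-order path, fc acts on each 512-d local feature;
        # the bilinear (second-order) path would instead need in_features = 512 * 512.
        self.fc = nn.Linear(512, num_classes)

    # forward() as in the snippet above

With the forward method above pasted into the class, a quick shape check under these assumptions would be:

model = SinkhornPoolingNet(num_classes=200)
out = model(torch.randn(2, 3, 224, 224))   # (2, 49, 200): one prediction per spatial location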