@elistevens · Created May 31, 2017
Cosine Normalization for PyTorch https://arxiv.org/pdf/1702.05870v2.pdf
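
Cosine normalization replaces the unbounded dot product w · x inside a layer with the cosine of the angle between w and x, i.e. (w · x) / (|w| |x|), which bounds pre-activations to [-1, 1]. The wrapper below approximates this by running the wrapped module as usual and then dividing each sample's output by |x| * |w| (note it uses the norm of the whole weight matrix, not a per-neuron weight norm as in the paper).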
import torch
from torch import nn
from torch.autograd import Variable

try:
    from util.logconf import logging
except ImportError:
    import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class CosNorm(nn.Module):
    """Wraps a module (e.g. nn.Linear) and rescales its output by
    1 / (|x| * |w|), following https://arxiv.org/pdf/1702.05870v2.pdf
    """
    def __init__(self, module, eps=1e-5):
        super().__init__()
        self.module = module
        self.eps = eps

    def forward(self, x):
        out = self.module(x)

        # We don't care about out_norm's initial contents; cloning just
        # gives us a tensor of the right size and type. Working on .data
        # keeps the normalization factors out of the autograd graph.
        out_norm = out.data.clone()

        # Norm of the entire weight matrix (a single scalar), not a
        # per-output-row norm.
        w_norm = float(self.module.weight.data.norm()) + self.eps

        # Scale each sample's output row by 1 / (|x_i| * |w|).
        for sample_ndx in range(x.data.size(0)):
            x_norm = float(x.data[sample_ndx].norm()) + self.eps
            out_norm[sample_ndx] = 1.0 / (x_norm * w_norm + self.eps)

        # Variable() wraps the tensor for the pre-0.4 autograd API.
        return out * Variable(out_norm)
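
A minimal usage sketch, assuming the pre-0.4 PyTorch API the gist targets; the layer sizes and random input are illustrative, not from the gist:

# Hypothetical example: wrap an nn.Linear so each sample's pre-activations
# are rescaled by 1 / (|x_i| * |W|) before the nonlinearity.
fc = CosNorm(nn.Linear(128, 64))
x = Variable(torch.randn(16, 128))  # batch of 16 illustrative samples
out = fc(x)                         # shape (16, 64), row i scaled by 1 / (|x_i| * |W|)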