-
-
Save elistevens/5383d51c6c3b3f756ce3b312ef53f3a8 to your computer and use it in GitHub Desktop.
Cosine Normalization for PyTorch https://arxiv.org/pdf/1702.05870v2.pdf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn as nn | |
from torch.autograd import Variable | |
# Prefer the project's logging configuration; fall back to the stdlib
# logger when running outside the project tree.
try:
    from util.logconf import logging
except ImportError:  # narrow: only a missing module should trigger the fallback
    import logging

log = logging.getLogger(__name__)
# Maximum verbosity while this module is under development; raise to
# logging.INFO or logging.WARN once it stabilizes.
log.setLevel(logging.DEBUG)
class CosNorm(nn.Module):
    """Cosine normalization wrapper (https://arxiv.org/pdf/1702.05870v2.pdf).

    Wraps ``module`` (expected to expose a ``weight`` tensor, e.g.
    ``nn.Linear``) and rescales each sample's output by
    ``1 / (|x_i| * |w| + eps)``, where ``|x_i|`` is the norm of that
    input sample and ``|w|`` is the norm of the wrapped module's weight.

    NOTE(review): the paper normalizes by the per-neuron (per-row)
    weight norm; this implementation uses the norm of the entire weight
    tensor — confirm that is intended.
    """

    def __init__(self, module, eps=1e-5):
        """
        :param module: layer to wrap; must have a ``weight`` attribute.
        :param eps: small constant guarding against division by zero.
        """
        super().__init__()
        self.module = module
        self.eps = eps

    def forward(self, x):
        out = self.module(x)
        # We don't actually care about out_norm's content; just need
        # right size and type — every row is overwritten below.
        out_norm = out.data.clone()
        w_norm = float(self.module.weight.data.norm()) + self.eps
        for sample_ndx in range(x.data.size(0)):
            # Per-sample scale 1 / (|x_i| * |w|); eps keeps the divisor
            # finite even for an all-zero input sample.
            # (The original wrapped this in a bare `except: raise`,
            # which only re-raised — removed as dead code.)
            x_norm = float(x.data[sample_ndx].norm()) + self.eps
            out_norm[sample_ndx] = 1.0 / (x_norm * w_norm + self.eps)
        return out * Variable(out_norm)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment