Skip to content

Instantly share code, notes, and snippets.

@dschaehi
Last active April 3, 2019 15:48
Show Gist options
  • Save dschaehi/bf15796cc375345db5e50e357ada32fd to your computer and use it in GitHub Desktop.
Save dschaehi/bf15796cc375345db5e50e357ada32fd to your computer and use it in GitHub Desktop.
Forces tensors to have norms within a range
class MaxNorm(object):
    """Callable that caps the L2 norm of a module's ``weight`` and ``bias``.

    An instance is called with a single module (it can be handed to
    ``torch.nn.Module.apply``).  For each of the module's ``weight`` and
    ``bias`` attributes that exists and is not ``None``, the L2 norm is
    computed along the last dimension and every slice whose norm exceeds
    ``max_value`` is rescaled in place so its norm equals ``max_value``.
    Slices whose norm is already within the limit are left (numerically)
    unchanged.

    Args:
        max_value: Upper bound for the L2 norm of each last-dimension slice.
        frequency: Stored on the instance as ``self.frequency``.  It is not
            read anywhere in this class; it is kept so existing callers that
            pass or inspect it keep working.
    """

    def __init__(self, max_value=1, frequency=5):
        self.frequency = frequency
        self.max_value = max_value
        # Smallest positive normal float32.  Added to the norms before the
        # division so an all-zero slice yields a scale of 0 instead of NaN.
        # ``torch.finfo`` is the public replacement for the private
        # ``torch.distributions.utils._finfo`` helper.
        self.tiny = torch.finfo(torch.float32).tiny

    def _clip_norm_(self, tensor):
        """Rescale ``tensor`` in place so last-dim L2 norms are <= max_value.

        Returns the same tensor object that was passed in.
        """
        norms = tensor.norm(p=2, dim=-1, keepdim=True)
        desired = norms.clamp(0, self.max_value)
        # desired / norms is 1 where the norm is already small enough and
        # max_value / norm where it is too large.
        tensor *= desired / (self.tiny + norms)
        return tensor

    def __call__(self, module):
        """Apply the max-norm constraint to ``module`` in place.

        Both ``weight`` and ``bias`` are constrained when present; an
        attribute that is missing or ``None`` (e.g. a layer built with
        ``bias=False``) is skipped.

        Returns:
            The constrained ``weight`` data tensor if the module has one,
            otherwise the constrained ``bias`` data tensor, otherwise
            ``None``.
        """
        result = None
        for attr_name in ("weight", "bias"):
            param = getattr(module, attr_name, None)
            if param is None:
                continue
            # Operate on ``.data`` so the in-place update is not tracked by
            # autograd.
            clipped = self._clip_norm_(param.data)
            if result is None:
                result = clipped
        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment