Skip to content

Instantly share code, notes, and snippets.

@gupta-abhay
Last active March 20, 2023 15:11
Show Gist options
  • Save gupta-abhay/b24bb56837178329e95e32b5c68933fd to your computer and use it in GitHub Desktop.
Save gupta-abhay/b24bb56837178329e95e32b5c68933fd to your computer and use it in GitHub Desktop.
Extending PyTorch's LayerNorm with an option to disable the learnable bias parameter
import numbers
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from torch.nn.parameter import Parameter
# Accepted spellings of the shape to normalize over: one int, a list of ints,
# or a torch.Size.
_shape_t = Union[int, List[int], Size]


class LayerNorm(nn.Module):
    """Layer normalization whose learnable bias can be switched off.

    The module stores ``normalized_shape`` and ``eps`` and forwards them,
    together with its ``weight`` and ``bias``, to
    ``torch.nn.functional.layer_norm``.

    Args:
        normalized_shape: Shape to normalize over. A single integer is
            wrapped into a one-element tuple; any other value is converted
            with ``tuple(...)``.
        eps: Epsilon value passed through to ``F.layer_norm``.
        elementwise_affine: When true, a learnable ``weight`` of shape
            ``normalized_shape`` is created (initialized to ones). When
            false, both ``weight`` and ``bias`` are registered as ``None``.
        bias: When true (and ``elementwise_affine`` is true), a learnable
            ``bias`` of shape ``normalized_shape`` is created (initialized
            to zeros). When false, ``bias`` is registered as ``None``.
        device: Forwarded to ``torch.empty`` when allocating parameters.
        dtype: Forwarded to ``torch.empty`` when allocating parameters.

    Attributes:
        weight: ``Parameter`` or ``None``.
        bias: ``Parameter`` or ``None``.
    """

    # ``bias`` is intentionally NOT listed here (nor annotated as ``bool``
    # below): on an instance, ``self.bias`` is a Parameter or None, not a
    # boolean constant. Only the constructor argument is a bool.
    __constants__ = [
        'normalized_shape',
        'eps',
        'elementwise_affine',
    ]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {
            'device': device,
            'dtype': dtype,
        }
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            if bias:
                self.bias = Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                # Registering None keeps ``self.bias`` readable (as None)
                # without adding a tensor to the parameter list.
                self.register_parameter('bias', None)
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and, if present, ``bias`` to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            # The bias may have been disabled independently of the weight.
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.layer_norm`` applied to ``input`` with this module's state."""
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        return (
            f'normalized_shape={self.normalized_shape}, eps={self.eps},'
            f' elementwise_affine={self.elementwise_affine},'
            f' bias={self.bias is not None}'
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment