Implementing Layer Normalization with Batch Normalization

Batch Normalization and Layer Normalization are two normalization methods commonly used in deep learning. Both can be written in the following form:

$$ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta $$

The main difference between them is how $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are computed.

  • Batch Normalization computes the mean and variance of each feature over the entire mini-batch. For example, if the input has shape $(N, C)$, then $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ have shape $(C,)$; if the input has shape $(N, C, H, W)$, $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ still have shape $(C,)$.
  • Layer Normalization computes a separate mean and variance for each sample in the mini-batch. For example, if the input has shape $(N, C)$, then $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ have shape $(N,)$ (a quick shape check follows this list).
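
The difference in shapes can be verified directly, as in this minimal sketch (the tensor x and the sizes 4 and 8 are assumed example values, not from the gist):

>>> import torch
>>> x = torch.rand(4, 8)     # (N, C) = (4, 8)
>>> x.mean(dim=0).shape      # Batch Normalization statistics: one value per feature
torch.Size([8])
>>> x.mean(dim=1).shape      # Layer Normalization statistics: one value per sample
torch.Size([4])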

We can therefore use Batch Normalization to compute Layer Normalization: treat the whole input as a single sample (batch size = 1) and move the original batch dimension to the new feature (channel) dimension, i.e.:

$$ (N, C) \to (1, N, C) $$

Applying Batch Normalization to this $(1, N, C)$ tensor then produces $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ of length $N$, which is exactly what Layer Normalization computes.
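
As a minimal sketch of the trick for a 2-D input (the tensor name x and the shape (4, 8) are assumed example values, not part of the gist):

import torch
import torch.nn.functional as F

x = torch.rand(4, 8)         # (N, C)
x_reshaped = x.unsqueeze(0)  # (1, N, C): batch size 1, the N samples become "channels"

# Batch normalization over (1, N, C) computes one mean/variance per "channel",
# i.e. per original sample -- the Layer Normalization statistics.
out = F.batch_norm(x_reshaped, None, None, training=True, eps=1e-5).squeeze(0)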

The complete, general implementation is as follows:

from functools import reduce
from operator import mul

import torch.nn.functional as F

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    """Layer Normalization implemented by Batch Normalization"""
    # Reshape input of shape [*, normalized_shape[0], ..., normalized_shape[-1]] to
    # shape [1, C, L], where L = prod(normalized_shape) and C = input.numel() // L.
    L = reduce(mul, normalized_shape, 1)
    C = input.numel() // L
    input_reshaped = input.view(1, C, L)

    # Batch normalization over (1, C, L) computes one mean/variance per "channel",
    # i.e. per sample -- exactly the Layer Normalization statistics.
    output = F.batch_norm(input_reshaped, None, None, training=True, eps=eps).view(input.shape)

    # Layer Normalization's affine parameters are elementwise over normalized_shape,
    # so apply them after the normalization rather than passing them to batch_norm.
    if weight is not None:
        output = output * weight
    if bias is not None:
        output = output + bias
    return output

Comparison against PyTorch's built-in Layer Normalization:

>>> import torch
>>> import torch.nn.functional as F

>>> x = torch.rand(2, 3, 4)
>>> F.layer_norm(x, [3, 4])
tensor([[[-0.3145,  1.1580, -0.5216, -1.0527],
         [-1.6420,  1.9395,  1.4807, -0.0406],
         [-0.5166, -0.5084,  0.0056,  0.0125]],

        [[-0.3949, -1.0793, -0.2180,  0.8326],
         [ 0.5099,  1.7204, -0.2157,  0.6653],
         [-1.3991,  1.3354, -0.1912, -1.5653]]])

>>> layer_norm(x, [3, 4])
tensor([[[-0.3145,  1.1580, -0.5216, -1.0527],
         [-1.6420,  1.9395,  1.4807, -0.0406],
         [-0.5166, -0.5084,  0.0056,  0.0125]],

        [[-0.3949, -1.0793, -0.2180,  0.8326],
         [ 0.5099,  1.7204, -0.2157,  0.6653],
         [-1.3991,  1.3354, -0.1912, -1.5653]]])
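
As an optional, stricter check (not part of the original gist), torch.allclose can compare the two results element-wise; it should report True up to floating-point tolerance:

>>> torch.allclose(layer_norm(x, [3, 4]), F.layer_norm(x, [3, 4]))
True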