神经网络normalization

深度学习中的Normalization,BN/LN/WN

为什么需要 Normalization

  • independent and identically distributed,简称为 i.i.d

并非所有机器学习模型的必然要求(比如 Naive Bayes 模型就建立在特征彼此独立的基础之上,而Logistic Regression 和 神经网络 则在非独立的特征数据上依然可以训练出很好的模型),但独立同分布的数据可以简化常规机器学习模型的训练、提升机器学习模型的预测能力,已经是一个共识。

  • 白化(whitening)

数据预处理步骤。

(1)去除特征之间的相关性 —> 独立;

(2)使得所有特征具有相同的均值和方差 —> 同分布。

  • 深度学习中的 Internal Covariate Shift

参数更新使每一层的数据分布发生变化,向前叠加,高层的受到数据变化的影响,需要不断重新适应底层的数据变化。

  • Internal Covariate Shift,简称 ICS.

    ML经典假设是“源空间(source domain)和目标空间(target domain)的数据分布(distribution)是一致的”

    covariate shift是指源空间和目标空间的条件概率是一致的,但是其边缘概率不同

    ​ 1. 给定输入,拟合label,条件概率一致的

    1. 层间计算导致,各层分布发生改变,边缘概率是不同的
  • ICS的问题

  1. 上层参数需要不断适应新的输入数据分布,降低学习速度
  2. 下层输入的变化可能趋向于变大或者变小,导致上层落入饱和区,使得学习过早停止 (想想sigmoid)
  3. 每层的更新都会影响到其它层,因此每层的参数更新策略需要尽可能的谨慎

Normalization 的通用框架与基本思想

标准的白化操作代价高昂,特别是我们还希望白化操作是可微的(每一点上必存在非垂直切线),保证白化操作可以通过反向传播来更新梯度。

Normalization 方法退而求其次,进行了简化的白化操作。

  • Normalization

先对其做平移和伸缩变换, 将 的分布规范化成在固定区间范围的标准分布。

平移参数(shift parameter), 缩放参数(scale parameter)

再平移参数(re-shift parameter), 再缩放参数(re-scale parameter)

最终得到的数据符合均值为 、方差为 的分布

变换为均值为 、方差为 的分布,也并不是严格的同分布,只是映射到了一个确定的区间范围而已

  • 再平移调整的意义
  1. 不会过分改变每一层计算结果
  2. 第一步的规范化会将几乎所有数据映射到激活函数的非饱和区(线性区),仅利用到了线性变化能力,从而降低了神经网络的表达能力。而进行再变换,则可以将数据从线性区变换到非线性区,恢复模型的表达能力(想想激活函数)

主流 Normalization 方法梳理

  1. Batch Normalization —— 纵向规范化:整个batch的不同维度(channel)
image

其中 是 mini-batch 的大小。由于 BN 是针对单个维度定义的,因此标准公式中的计算均为 element-wise。

然后,用一个 mini-batch 的一阶统计量和二阶统计量,规范每一个输入维度

KEYPOINT:mini-batch数据决定,x每个维度的分布,上图可理解为RGB三个通道。

要求:每个 mini-batch 比较大,数据分布比较接近,充分的 shuffle

不适用:动态的网络结构 和 RNN 网络 (最后才知道mini-batch的\(\mu\))。Batch Normalization基于一个mini batch的数据计算均值和方差,而不是基于整个Training set来做,相当于进行梯度计算式引入噪声。因此,Batch Normalization不适用于对噪声敏感的强化学习、生成模型(Generative model:GAN,VAE)使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def batch_normalization_layer(inputs, out_size, isTrain=True):
# in_size, out_size = inputs.get_shape()
pop_mean = tf.Variable(tf.zeros([out_size]),trainable=False)
pop_var = tf.Variable(tf.ones([out_size]),trainable=False)
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
eps = 0.001
decay = 0.999
if isTrain:
# batch的mean和var。 注原始维度为[batch_size, height, width, channel]
batch_mean, batch_var = tf.nn.moments(inputs,[0,1,2])
print(batch_mean.get_shape())
# 记录训练的mean和var
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1-decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1-decay))
with tf.control_dependencies([train_mean,train_var]):
return tf.nn.batch_normalization(inputs,batch_mean,batch_var,shift,scale,eps)
else:
return tf.nn.batch_normalization(inputs,pop_mean,pop_var,shift,scale,eps)

ref Synchronized-BatchNorm-PyTorch,多GPU分布式同步各个节点的数据mean和var

  1. Layer Normalization —— 横向规范化:单个输入 https://arxiv.org/abs/1607.06450
image

考虑一层所有维度的输入,计算该层的平均输入值和输入方差,然后用同一个规范化操作来转换各个维度的输入

枚举了该层所有的输入神经元。对应到标准公式中,四大参数 , , , 均为标量(BN中是向量),所有输入共享一个规范化变换

KEYPOINT:LN 针对单个训练样本进行,用于 小mini-batch场景、动态网络场景和 RNN,特别是自然语言处理领域。此外,LN 不需要保存 mini-batch 的均值和方差,节省了额外的存储空间

NOTE:如果不同输入特征不属于相似的类别(比如颜色和大小),那么 LN 的处理可能会降低模型的表达能力。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class layer_norm(Function):
@staticmethod
def forward(input, gain=None, bias=None):
# 这里的输入是unroll的,[batch_size, h x w x c],按实例norm
mean = input.mean(-1, keepdim=True)
var = input.var(-1, unbiased=False, keepdim=True)
input_normalized = (input - mean) / torch.sqrt(var + 1e-9)

if gain is not None and bias is not None:
output = input_normalized * gain + bias
elif not (gain is None and bias is None):
raise RuntimeError("gain and bias of LayerNorm should be both None or not None!")
else:
output = input_normalized

return output
...
  1. Weight Normalization —— 参数规范化 https://arxiv.org/abs/1602.07868

将以下方程

理解为: .

  • BN 和 LN 均将规范化应用于输入的特征数据

  • WN将规范化应用于线性变换的权重

用神经元的权重的欧氏范数对输入数据进行 scale。

是神经元的权重的欧氏范数,因此 是单位向量,决定了 的方向;

是标量,决定了 的长度。

KEYPOINT:WN 的规范化不直接使用输入数据的统计量,因此避免了 BN 过于依赖 mini-batch 的不足,以及 LN 每层唯一转换器的限制,同时也可以用于动态网络结构

Weight Normalization对通过标量g和向量v对权重W进行重写,重写向量v是固定的,因此,基于Weight Normalization的Normalization比Batch Normalization引入更少的噪声。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def _weight_norm(v, g):
'v就是weights'
norm = torch.norm(v, 2)
return v * (g * norm)

def norm_except_dim(v, pow, dim):
'计算g: norm_except_dim(weight, 2, dim).data'
if dim is None:
return v.norm()
if dim != 0:
v = v.transpose(0, dim)
output_size = (v.size(0),) + (1,) * (v.dim() - 1)
v = v.contiguous().view(v.size(0), -1).norm(dim=1).view(*output_size)
if dim != 0:
v = v.transpose(0, dim)
return v

class WeightNorm:
...
def compute_weight(self, module):
g = getattr(module, self.name + '_g')
v = getattr(module, self.name + '_v')
return _weight_norm(v, g)
...
  1. Cosine Normalization —— 余弦规范化

其中 的夹角。所有的数据就都是 [-1, 1] 区间。

超简单的变化,直接在wx的上scale,并且不需要再次缩放。

将 点积》》》变为余弦相似度

  1. Instance Norm
image

InstanceNorm等价于当Group Normnum_groups等于num_channel.

  1. Group Norm https://arxiv.org/abs/1803.08494
image

Group Norm中group的数量是1的时候, 是与LayerNorm是等价的

1
2
3
4
5
6
7
8
9
10
11
12
13
def GroupNorm(x, gamma, beta, G, eps=1e−5):
# x: input features with shape [N,C,H,W]
# gamma, beta: scale and offset, with shape [1,C,1,1]
# G: number of groups for GN
N, C, H, W = x.shape
# group划分
x = tf.reshape(x, [N, G, C // G, H, W])
# 按group求mean var
mean, var = tf.nn.moments(x, [2, 3, 4], keepdims=True)

x = (x−mean) / tf.sqrt(var + eps)
x = tf.reshape(x, [N, C, H, W])
return x∗gamma + beta

Normalization 为什么会有效?

  1. 权重伸缩不变性(weight scale invariance)

其中

由于

因此,权重的伸缩变化不会影响反向梯度的 Jacobian 矩阵,因此也就对反向传播没有影响,避免了反向传播时因为权重过大或过小导致的梯度消失或梯度爆炸问题,从而加速了神经网络的训练

  1. 参数正则

由于

因此,下层的权重值越大,\(\lambda\)越大,那么其梯度就越小。这样,参数的变化就越稳定,相当于实现了参数正则化的效果,避免参数的大幅震荡,提高网络的泛化性能。

  1. 数据伸缩不变性(data scale invariance)

当数据 按照常量 进行伸缩时,得到的规范化后的值保持不变,即:

其中

数据伸缩不变性仅对 BN、LN 和 CN 成立。WN 不具有这一性质。很明显。

  1. 数据伸缩不变性可以有效地减少梯度弥散,简化对学习率的选择

某一层神经元 而言,展开可得(以下式子为示意,没写入激活函数)

每一层神经元的输出依赖于底下各层的计算结果。再次回忆activition function的图像

如果没有正则化,当下层输入发生伸缩变化时,经过层层传递,可能会导致数据发生剧烈的膨胀或者弥散,从而也导致了反向计算时的梯度爆炸或梯度弥散。

而言,其输入 永远保持标准的分布,这就使得高层的训练更加简单。从梯度的计算公式来看:

数据的伸缩变化也不会影响到对该层的权重参数更新,使得训练过程更加鲁棒,简化了对学习率的选择。

参考链接:https://zhuanlan.zhihu.com/p/33173246